iofu728 committed
Commit 27e09a4
1 Parent(s): 9451f4b

Feature(MInference): add local mode

Files changed (2):
  1. app.py +90 -97
  2. requirements.txt +1 -2
app.py CHANGED

@@ -1,53 +1,47 @@
 import subprocess
 import os
-# Install flash attention, skipping CUDA build if necessary
-
-# os.environ["CPATH"] = "$CPATH:/usr/local/cuda/include"
-# os.environ["CUDA_HOME"] = "/usr/local/cuda"
-# os.environ["PATH"] = "$PATH:$CUDA_HOME/bin"
-# os.environ["LIBRARY_PATH"] = "$LIBRARY_PATH:/usr/local/cuda/lib64"
-
-subprocess.run(
-    "/home/user/.pyenv/shims/pip install flash-attn --no-build-isolation",
-    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-    shell=True,
-)
-# subprocess.run(
-#     "/home/user/.pyenv/shims/pip install pycuda==2023.1",
-#     env={"CPATH": "$CPATH:/usr/local/cuda/include", "LIBRARY_PATH": "$LIBRARY_PATH:/usr/local/cuda/lib64"},
-#     shell=True,
-# )
+import torch
 
 import gradio as gr
 import os
 import spaces
-from transformers import AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+from transformers.utils.import_utils import _is_package_available
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
-DESCRIPTION = '''
-<div>
-<h1 style="text-align: center;">Meta Llama3 8B</h1>
-<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is the new open LLM and comes in two sizes: 8b and 70b. Feel free to play with it, or duplicate to run privately!</p>
-<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
-<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
-</div>
-'''
+DESCRIPTION = """
+# MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention (Under Review) [[paper](https://arxiv.org/abs/2406.05736)]
+_Huiqiang Jiang†, Yucheng Li†, Chengruidong Zhang†, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
+
+<h2 style="text-align: center;"><a href="https://github.com/microsoft/MInference" target="blank"> [Code]</a>
+<a href="https://hqjiang.com/minference.html" target="blank"> [Project Page]</a>
+<a href="https://arxiv.org/abs/2406.05736" target="blank"> [Paper]</a></h2>
+
+<font color="brown"><b>This is only a deployment demo. Due to limited GPU resources, we do not provide an online demo. You will need to follow the code below to try MInference locally.</b></font>
+
+```bash
+git clone https://huggingface.co/spaces/microsoft/MInference
+cd MInference
+pip install -r requirements.txt
+pip install flash_attn pycuda==2023.1
+python app.py
+```
+<br/>
+"""
 
 LICENSE = """
-<p/>
----
-Built with Meta Llama 3
+<div style="text-align: center;">
+<p>© 2024 Microsoft</p>
+</div>
 """
 
 PLACEHOLDER = """
 <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-<img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
-<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
+<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaMA-3-8B-Gradient-1M w/ MInference</h1>
 <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
 </div>
 """

@@ -58,34 +52,28 @@ h1 {
   text-align: center;
   display: block;
 }
-#duplicate-button {
-  margin: auto;
-  color: white;
-  background: #1565c0;
-  border-radius: 100vh;
-}
 """
 
 # Load the tokenizer and model
-model_name = "gradientai/Llama-3-8B-Instruct-262k"
+model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") # to("cuda:0")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name, torch_dtype="auto", device_map="auto"
+)  # to("cuda:0")
+
+if torch.cuda.is_available() and _is_package_available("pycuda"):
+    from minference import MInference
+
+    minference_patch = MInference("minference", model_name)
+    model = minference_patch(model)
 
-from minference import MInference
-minference_patch = MInference("minference", model_name)
-model = minference_patch(model)
+terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
 
-terminators = [
-    tokenizer.eos_token_id,
-    tokenizer.convert_tokens_to_ids("<|eot_id|>")
-]
 
 @spaces.GPU(duration=120)
-def chat_llama3_8b(message: str,
-                   history: list,
-                   temperature: float,
-                   max_new_tokens: int
-                   ) -> str:
+def chat_llama3_8b(
+    message: str, history: list, temperature: float, max_new_tokens: int
+) -> str:
     """
     Generate a streaming response using the llama3-8b model.
     Args:

@@ -99,81 +87,86 @@ def chat_llama3_8b(message: str,
     # global model
     conversation = []
     for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-    print(model.device)
-    # subprocess.run(
-    #     "pip install pycuda==2023.1",
-    #     shell=True,
-    # )
-    # if "has_patch" not in model.__dict__:
-    #     from minference import MInference
-    #     minference_patch = MInference("minference", model_name)
-    #     model = minference_patch(model)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(
+        model.device
+    )
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
 
     generate_kwargs = dict(
-        input_ids= input_ids,
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         eos_token_id=terminators,
     )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
     if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
+        generate_kwargs["do_sample"] = False
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     outputs = []
     for text in streamer:
         outputs.append(text)
-        #print(outputs)
+        # print(outputs)
         yield "".join(outputs)
-
+
 
 # Gradio block
-chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label="Gradio ChatInterface")
 
 with gr.Blocks(fill_height=True, css=css) as demo:
-
+
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
     gr.ChatInterface(
         fn=chat_llama3_8b,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+        additional_inputs_accordion=gr.Accordion(
+            label="⚙️ Parameters", open=False, render=False
+        ),
        additional_inputs=[
-            gr.Slider(minimum=0,
-                      maximum=1,
-                      step=0.1,
-                      value=0.95,
-                      label="Temperature",
-                      render=False),
-            gr.Slider(minimum=128,
-                      maximum=4096,
-                      step=1,
-                      value=512,
-                      label="Max new tokens",
-                      render=False ),
-            ],
+            gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.1,
+                value=0.95,
+                label="Temperature",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=128,
+                maximum=4096,
+                step=1,
+                value=512,
+                label="Max new tokens",
+                render=False,
+            ),
+        ],
         examples=[
-            ['How to setup a human base on Mars? Give short answer.'],
-            ['Explain theory of relativity to me like I’m 8 years old.'],
-            ['What is 9,000 * 9,000?'],
-            ['Write a pun-filled happy birthday message to my friend Alex.'],
-            ['Justify why a penguin might make a good king of the jungle.']
-            ],
+            ["How to setup a human base on Mars? Give short answer."],
+            ["Explain theory of relativity to me like I’m 8 years old."],
+            ["What is 9,000 * 9,000?"],
+            ["Write a pun-filled happy birthday message to my friend Alex."],
+            ["Justify why a penguin might make a good king of the jungle."],
+        ],
         cache_examples=False,
-        )
-
+    )
+
     gr.Markdown(LICENSE)
-
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=False)
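For a quick read of what this commit changes, the core of the new local mode is the guarded MInference patch in app.py above. The snippet below is a minimal sketch that restates that path outside the Gradio app, assuming a CUDA GPU and the manually installed `flash_attn`/`pycuda` extras listed in the DESCRIPTION; the package and model names are taken directly from the committed code.

```python
# Minimal sketch of the local-mode path added in this commit (mirrors app.py).
# Assumes a CUDA GPU plus `flash_attn`, `pycuda`, and `minference` are installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.import_utils import _is_package_available

model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# Only patch when the sparse-attention kernels can actually run; otherwise the
# model keeps its stock attention path, which is what lets the Space start
# without a GPU.
if torch.cuda.is_available() and _is_package_available("pycuda"):
    from minference import MInference

    minference_patch = MInference("minference", model_name)
    model = minference_patch(model)
```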
requirements.txt CHANGED

@@ -2,5 +2,4 @@ triton==2.1.0
 accelerate
 transformers
 wheel
-setuptools
-pycuda==2023.1
+setuptools
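Note that `pycuda` is no longer pinned in requirements.txt; it is installed manually (see the DESCRIPTION block) and probed at runtime. An illustrative pre-flight check along these lines (not part of the commit) shows whether the conditions app.py uses to enable MInference would be met in the current environment:

```python
# Illustrative pre-flight check (not part of the commit): report whether the
# runtime conditions app.py uses to enable the MInference patch are satisfied.
import torch
from transformers.utils.import_utils import _is_package_available

print("CUDA available:", torch.cuda.is_available())
for pkg in ("pycuda", "minference"):
    print(f"{pkg} importable:", _is_package_available(pkg))
```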