BK-Lee commited on
Commit
63c3081
1 Parent(s): f443686
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -22,6 +22,9 @@ from torchvision.transforms.functional import pil_to_tensor
22
  # accel
23
  accel = Accelerator()
24
 
 
 
 
25
  # loading model
26
  model_1_8, tokenizer_1_8 = load_model(size='1.8b')
27
 
@@ -50,7 +53,10 @@ def threading_function(inputs, streamer, device, model, tokenizer, temperature,
50
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
51
 
52
  # model selection
53
- if "1.8B" in link:
 
 
 
54
  model = model_1_8
55
  tokenizer = tokenizer_1_8
56
  elif "3.8B" in link:
@@ -127,7 +133,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
127
  yield buffer
128
 
129
  demo = gr.ChatInterface(fn=bot_streaming,
130
- additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
131
  additional_inputs_accordion="Generation Hyperparameters",
132
  theme=gr.themes.Soft(),
133
  title="Phantom",
 
22
  # accel
23
  accel = Accelerator()
24
 
25
+ # loading model
26
+ model_0_5, tokenizer_0_5 = load_model(size='0.5b')
27
+
28
  # loading model
29
  model_1_8, tokenizer_1_8 = load_model(size='1.8b')
30
 
 
53
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
54
 
55
  # model selection
56
+ if "0.5B" in link:
57
+ model = model_0_5
58
+ tokenizer = tokenizer_0_5
59
+ elif "1.8B" in link:
60
  model = model_1_8
61
  tokenizer = tokenizer_1_8
62
  elif "3.8B" in link:
 
133
  yield buffer
134
 
135
  demo = gr.ChatInterface(fn=bot_streaming,
136
+ additional_inputs = [gr.Radio(["0.5B", "1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
137
  additional_inputs_accordion="Generation Hyperparameters",
138
  theme=gr.themes.Soft(),
139
  title="Phantom",