gregH committed on
Commit
85f7114
1 parent: cd90cd9

Create app.py

Files changed (1)
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ import time
+ import numpy as np
+ from torch.nn import functional as F
+ import os
+ from threading import Thread
+
+ print("Starting to load the model into memory")
+ m = AutoModelForCausalLM.from_pretrained(
+     "stabilityai/stablelm-2-zephyr-1_6b", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True)
+ tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b", trust_remote_code=True)
+ # Use CUDA when available for an optimal experience
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ m = m.to(device)
+ print("Successfully loaded the model into memory")
+
+
+ start_message = ""
+
+ def user(message, history):
+     # Clear the textbox and append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
+ def chat(message, history):
+     conversation = []
+     for item in history:
+         conversation.append({"role": "user", "content": item[0]})
+         if item[1] is not None:
+             conversation.append({"role": "assistant", "content": item[1]})
+     conversation.append({"role": "user", "content": message})
+     messages = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+     # Tokenize the formatted prompt string
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=0.75,
+         num_beams=1,
+     )
+     # Run generation in a background thread so tokens can be streamed as they arrive
+     t = Thread(target=m.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Accumulate the generated text as it streams in
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         # Yield the text generated so far to update the conversation view incrementally
+         yield partial_text
+
+ demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Stable LM 2 Zephyr 1.6b")
+ demo.launch()