Henrychur committed
Commit dee6b79
Parent: 01c4a7b

Update app.py

Files changed (1): app.py (+39, -2)
app.py CHANGED

@@ -2,6 +2,32 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from typing import List, Literal, Sequence, TypedDict
+
+Role = Literal["system", "user", "assistant"]
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+Dialog = Sequence[Message]
+
+class ChatFormat:
+    def encode_header(self, message: Message) -> str:
+        return f"{message['role']}\n\n"
+
+    def encode_message(self, message: Message) -> str:
+        header = self.encode_header(message)
+        return f"{header}{message['content'].strip()}"
+
+    def encode_dialog_prompt(self, dialog: Dialog) -> str:
+        dialog_str = ""
+        for message in dialog:
+            dialog_str += self.encode_message(message)
+        dialog_str += self.encode_header({"role": "assistant", "content": ""})
+        return dialog_str
+
+
 class MedS_Llama3:
     def __init__(self, model_path: str):
         # Load the model onto the CPU
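For context, the new ChatFormat class builds a plain-text, role-tagged prompt: each message is rendered as a role header followed by two newlines and its stripped content, messages are concatenated with no separator, and the dialog ends with an empty assistant header that cues the model to generate. The method names appear modeled on Meta's reference Llama 3 ChatFormat, though this version emits plain-text headers rather than special tokens. A minimal, self-contained sketch of the string it produces (the class is condensed from the hunk above; the example messages are illustrative, not from the commit):

# Condensed copy of the ChatFormat logic above, to show its output string.
from typing import Literal, Sequence, TypedDict

class Message(TypedDict):
    role: Literal["system", "user", "assistant"]
    content: str

class ChatFormat:
    def encode_header(self, message: Message) -> str:
        return f"{message['role']}\n\n"

    def encode_message(self, message: Message) -> str:
        return f"{self.encode_header(message)}{message['content'].strip()}"

    def encode_dialog_prompt(self, dialog: Sequence[Message]) -> str:
        dialog_str = "".join(self.encode_message(m) for m in dialog)
        # Trailing empty assistant header cues the model to answer next.
        return dialog_str + self.encode_header({"role": "assistant", "content": ""})

prompt = ChatFormat().encode_dialog_prompt([
    {"role": "system", "content": "You are a medical assistant."},  # illustrative
    {"role": "user", "content": "What is hypertension?"},           # illustrative
])
print(repr(prompt))
# 'system\n\nYou are a medical assistant.user\n\nWhat is hypertension?assistant\n\n'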
 
@@ -19,10 +45,21 @@ class MedS_Llama3:
         )
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.model.eval()
+        self.prompt_engine = ChatFormat()
         print('Model and tokenizer loaded on CPU!')
-
+
+    def __build_inputs_for_llama3(self, query: str, instruction: str) -> str:
+        input_ss = [
+            {"role": 'system', "content": instruction},
+            {"role": 'user', "content": query}
+        ]
+        return self.prompt_engine.encode_dialog_prompt(input_ss)
+
     def chat(self, query: str, instruction: str, max_output_tokens: int) -> str:
-        input_sentence = f"{instruction}\n\n{query}"
+
+        formatted_query = f"Input:\n{query}\nOutput:\n"
+        input_sentence = self.__build_inputs_for_llama3(formatted_query, instruction)
+
         input_tokens = self.tokenizer(
             input_sentence,
             return_tensors="pt",
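Net effect on chat(): before this commit the model received the bare concatenation of instruction and query; after it, the query is wrapped in Input:/Output: markers and routed through __build_inputs_for_llama3, so the tokenizer sees a role-tagged dialog string instead. A small before/after sketch of the prompt assembly (illustrative strings; note that content.strip() inside ChatFormat drops the trailing newline after "Output:"):

# Before/after prompt assembly, reproduced outside the class for comparison.
instruction = "You are a medical assistant."  # illustrative
query = "What is hypertension?"               # illustrative

# Before this commit: plain concatenation.
old_prompt = f"{instruction}\n\n{query}"

# After this commit: Input/Output wrapping plus role headers,
# equivalent to __build_inputs_for_llama3 + ChatFormat above.
formatted_query = f"Input:\n{query}\nOutput:\n"
new_prompt = (
    f"system\n\n{instruction.strip()}"
    f"user\n\n{formatted_query.strip()}"
    "assistant\n\n"
)

print(repr(old_prompt))
# 'You are a medical assistant.\n\nWhat is hypertension?'
print(repr(new_prompt))
# 'system\n\nYou are a medical assistant.user\n\nInput:\nWhat is hypertension?\nOutput:assistant\n\n'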