File size: 2,363 Bytes
7f9f96d
 
 
 
e4f695b
7f9f96d
 
 
 
 
e4f695b
7f9f96d
e4f695b
 
7f9f96d
e4f695b
 
 
 
 
 
 
 
 
7f9f96d
 
e4f695b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9f96d
e4f695b
 
 
 
 
 
 
 
7f9f96d
e4f695b
 
 
 
 
 
 
 
 
 
7f9f96d
e4f695b
 
7f9f96d
e4f695b
 
 
 
 
7f9f96d
e4f695b
 
b878b56
 
 
 
3ccf1a4
 
 
 
b878b56
e4f695b
b878b56
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from threading import Thread
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_id = 'baichuan-inc/Baichuan2-13B-Chat'

# The chat model is only loaded when CUDA is available. On a CPU-only host
# ``model`` stays None; the tokenizer below is loaded either way.
model = None
if torch.cuda.is_available():
  model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
  )
  # ``quantize`` is not a generic transformers API; presumably it comes from
  # the model's remote code (trust_remote_code=True). Quantize to 4 bits,
  # then move the result onto the GPU.
  model = model.quantize(4).cuda()
  model.generation_config = GenerationConfig.from_pretrained(model_id)

# use_fast=False requests the non-fast (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
  model_id,
  use_fast=False,
  trust_remote_code=True,
)

def get_prompt(
  message: str,
  chat_history: list[tuple[str, str]],
  system_prompt: str
) -> str:
  """Render a system prompt, earlier turns and a new message as one string.

  The layout uses ``<s>[INST] <<SYS>> ... <</SYS>>``, ``[/INST]`` and
  ``</s>`` markers (it appears to follow the Llama-2 chat template).

  Whitespace handling is asymmetric:
  - The first user text of the conversation is inserted verbatim. This is
    the first history entry, or ``message`` when the history is empty.
  - Every later user text, and every assistant response, is ``strip()``-ed.

  Args:
    message: The newest user message.
    chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
    system_prompt: Text placed between the ``<<SYS>>`` markers.

  Returns:
    The concatenated prompt, ending with ``[/INST]``.
  """
  parts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
  is_first_turn = True
  for user_text, reply in chat_history:
    # Only the very first user turn keeps its surrounding whitespace.
    if not is_first_turn:
      user_text = user_text.strip()
    parts.append(f'{user_text} [/INST] {reply.strip()} </s><s>[INST] ')
    is_first_turn = False
  final_text = message if is_first_turn else message.strip()
  parts.append(f'{final_text} [/INST]')
  return ''.join(parts)

def get_input_token_length(
  message: str,
  chat_history: list[tuple[str, str]],
  system_prompt: str
) -> int:
  """Count the tokens in the prompt that ``get_prompt`` builds for these inputs.

  The prompt is tokenized as a batch of one, with ``add_special_tokens=False``
  and NumPy output. The length of the last axis of ``input_ids`` is returned.
  ``add_special_tokens=False`` is presumably used because the template text
  already spells out ``<s>``/``</s>`` itself.

  NOTE(review): ``run`` does not use ``get_prompt``; it sends role/content
  messages to ``model.chat``. This count therefore measures a different
  prompt layout than the one actually sent to the model. Treat it as an
  estimate, and confirm it is adequate for any length limit built on it.

  Args:
    message: The newest user message.
    chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
    system_prompt: System text embedded in the prompt.

  Returns:
    Number of token ids in the tokenized prompt.
  """
  batch = tokenizer(
    [get_prompt(message, chat_history, system_prompt)],
    return_tensors='np',
    add_special_tokens=False,
  )
  token_ids = batch['input_ids']
  return token_ids.shape[-1]

def run(
  message: str,
  chat_history: list[tuple[str, str]],
  system_prompt: str,
  max_new_tokens: int = 1024,
  temperature: float = 1.0,
  top_p: float = 0.95,
  top_k: int = 5
) -> Iterator[str]:
  """Generate a reply to ``message`` and yield it as progressively longer text.

  Steps:
  1. ``chat_history`` is converted into alternating ``user``/``assistant``
     role messages.
  2. ``message`` is appended as the final ``user`` message.
  3. The list is passed to ``model.chat``.
  4. Each item ``model.chat`` produces is appended to a running string, and
     that string is yielded after every append. The last value yielded is
     the complete reply.

  NOTE(review): ``system_prompt``, ``max_new_tokens``, ``temperature``,
  ``top_p`` and ``top_k`` are accepted for interface compatibility, but
  nothing here forwards them to ``model.chat``. TODO: confirm whether they
  should be wired through (e.g. via a ``generation_config`` argument).

  Args:
    message: The newest user message.
    chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
    system_prompt: Currently unused (see note above).
    max_new_tokens: Currently unused (see note above).
    temperature: Currently unused (see note above).
    top_p: Currently unused (see note above).
    top_k: Currently unused (see note above).

  Yields:
    The reply text accumulated so far.

  Raises:
    RuntimeError: If ``model`` is None, i.e. CUDA was unavailable when the
      module was imported, so no model was loaded.
  """
  if model is None:
    raise RuntimeError(
      'The chat model is not loaded (CUDA was not available at import time).'
    )

  history = []
  for user_input, response in chat_history:
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
  history.append({"role": "user", "content": message})

  # model.chat is called without stream=True. The items it produces are
  # concatenated as strings, so each yield is a growing prefix of the reply
  # (character by character if model.chat returns a single str).
  result = ""
  for piece in model.chat(tokenizer, history):
    result += piece
    yield result