chrisjay commited on
Commit
8ffc092
β€’
1 Parent(s): 0d9bc50

work on dialogue models

Browse files
Files changed (3) hide show
  1. README.md +3 -8
  2. app.py +4 -2
  3. model.py +16 -15
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Llama 2 13b Chat
3
- emoji: 🦙
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
@@ -8,12 +8,7 @@ sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
  license: other
11
- suggested_hardware: a10g-small
12
  duplicated_from: huggingface-projects/llama-2-13b-chat
13
  ---
14
 
15
- # LLAMA v2 Models
16
-
17
- Llama v2 was introduced in [this paper](https://arxiv.org/abs/2307.09288).
18
-
19
- This Space demonstrates [Llama-2-13b-chat-hf](meta-llama/Llama-2-13b-chat-hf) from Meta. Please, check the original model card for details.
 
1
  ---
2
+ title: Chat with Masakhane Dialogue Models
3
+ emoji: 🌍
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
11
  duplicated_from: huggingface-projects/llama-2-13b-chat
12
  ---
13
 
14
+ # Chat with Masakhane Dialogue Models
 
 
 
 
app.py CHANGED
@@ -12,7 +12,9 @@ MAX_INPUT_TOKEN_LENGTH = 4000
12
 
13
  DESCRIPTION = """
14
  # Masakhane Dialogue Models
15
- This Space demonstrates the dialogue models for Nigerian Pidgin, an African language.\
 
 
16
  🔎 For more information, visit [our homepage](https://www.masakhane.io/).
17
 
18
  """
@@ -21,7 +23,7 @@ This Space demonstrates the dialogue models for Nigerian Pidgin, an African lang
21
 
22
 
23
  if not torch.cuda.is_available():
24
- DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
25
 
26
 
27
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
 
12
 
13
  DESCRIPTION = """
14
  # Masakhane Dialogue Models
15
+
16
+ This Space demonstrates the dialogue models for Nigerian Pidgin, an African language.\n
17
+
18
  🔎 For more information, visit [our homepage](https://www.masakhane.io/).
19
 
20
  """
 
23
 
24
 
25
  if not torch.cuda.is_available():
26
+ DESCRIPTION += '\n<p>Running on CPU 🥶 This demo will be very slow on CPU.</p>'
27
 
28
 
29
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
model.py CHANGED
@@ -5,18 +5,18 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIt
5
 
6
  model_id = 'tosin/dialogpt_afriwoz_pidgin'
7
 
8
- if torch.cuda.is_available():
9
- config = AutoConfig.from_pretrained(model_id)
10
- config.pretraining_tp = 1
11
- model = AutoModelForCausalLM.from_pretrained(
12
- model_id,
13
- config=config,
14
- torch_dtype=torch.float16,
15
- load_in_4bit=True,
16
- device_map='auto'
17
  )
18
- else:
19
- model = None
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
 
22
 
@@ -51,10 +51,11 @@ def run(message: str,
51
  top_p: float = 0.95,
52
  top_k: int = 50) -> Iterator[str]:
53
  prompt = get_prompt(message, chat_history, system_prompt)
54
- inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
 
55
 
56
  streamer = TextIteratorStreamer(tokenizer,
57
- timeout=10.,
58
  skip_prompt=True,
59
  skip_special_tokens=True)
60
  generate_kwargs = dict(
@@ -62,8 +63,8 @@ def run(message: str,
62
  streamer=streamer,
63
  max_new_tokens=max_new_tokens,
64
  do_sample=True,
65
- top_p=top_p,
66
- top_k=top_k,
67
  temperature=temperature,
68
  num_beams=1,
69
  )
 
5
 
6
  model_id = 'tosin/dialogpt_afriwoz_pidgin'
7
 
8
+ #if torch.cuda.is_available():
9
+ config = AutoConfig.from_pretrained(model_id)
10
+ config.pretraining_tp = 1
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ model_id,
13
+ config=config,
14
+ #torch_dtype=torch.float16,
15
+ #load_in_4bit=True,
16
+ device_map='cpu'
17
  )
18
+ #else:
19
+ # model = None
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
 
22
 
 
51
  top_p: float = 0.95,
52
  top_k: int = 50) -> Iterator[str]:
53
  prompt = get_prompt(message, chat_history, system_prompt)
54
+ #inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
55
+ inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False)
56
 
57
  streamer = TextIteratorStreamer(tokenizer,
58
+ timeout=40.,
59
  skip_prompt=True,
60
  skip_special_tokens=True)
61
  generate_kwargs = dict(
 
63
  streamer=streamer,
64
  max_new_tokens=max_new_tokens,
65
  do_sample=True,
66
+ #top_p=top_p,
67
+ #top_k=top_k,
68
  temperature=temperature,
69
  num_beams=1,
70
  )