cryptocalypse commited on
Commit
3287a07
1 Parent(s): bdaa284

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +22 -14
gen.py CHANGED
@@ -1,10 +1,14 @@
1
  import torch
2
- from transformers import pipeline
3
  import sys
4
  import sys
 
 
 
 
 
 
 
5
 
6
- # Cargar el pipeline de generación de texto
7
- pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-alpha", torch_dtype=torch.bfloat16, device_map="auto")
8
 
9
 
10
  # Definir el prompt para generar un JSON con eventos anidados
@@ -149,16 +153,20 @@ prompt = (
149
 
150
  def generate(event):
151
  # Generar el texto usando el modelo
152
- # We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
153
- messages = [
154
- {
155
- "role": "system",
156
- "content": prompt,
157
- },
158
- {"role": "user", "content": event},
159
- ]
160
- prompt_tuned = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
161
- outputs = pipe(prompt_tuned, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
 
 
 
 
162
 
163
  # Imprimir la salida generada
164
- return outputs
 
1
  import torch
 
2
  import sys
3
  import sys
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained('stabilityai/stablelm-2-zephyr-1_6b')
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ 'stabilityai/stablelm-2-zephyr-1_6b',
9
+ device_map="auto"
10
+ )
11
 
 
 
12
 
13
 
14
  # Definir el prompt para generar un JSON con eventos anidados
 
153
 
154
  def generate(event):
155
  # Generar el texto usando el modelo
156
+ prompt = [{'role':'system','content':event},{'role': 'user', 'content': 'Which famous math number begins with 1.6 ...?'}]
157
+ inputs = tokenizer.apply_chat_template(
158
+ prompt,
159
+ add_generation_prompt=True,
160
+ return_tensors='pt'
161
+ )
162
+
163
+ tokens = model.generate(
164
+ inputs.to(model.device),
165
+ max_new_tokens=1024,
166
+ temperature=0.5,
167
+ do_sample=True
168
+ )
169
+
170
 
171
  # Imprimir la salida generada
172
+ return tokenizer.decode(tokens[0], skip_special_tokens=False)