koziev ilya commited on
Commit
e3f5def
1 Parent(s): a31c007

немного причесал код, убрал лишние манипуляции с выдачей gpt

Browse files
Files changed (1) hide show
  1. README.md +7 -6
README.md CHANGED
@@ -44,6 +44,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
 
46
  tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
 
47
  model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
48
  model.to(device)
49
 
@@ -51,8 +52,10 @@ model.to(device)
51
  # В конце добавляем символ "#"
52
  input_text = """<s>- Как тебя зовут?
53
  - Джульетта Мао #"""
54
- encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
55
- encoded_prompt = encoded_prompt.to(device)
 
 
56
 
57
  output_sequences = model.generate(
58
  input_ids=encoded_prompt,
@@ -63,12 +66,10 @@ output_sequences = model.generate(
63
  repetition_penalty=1.2,
64
  do_sample=True,
65
  num_return_sequences=1,
66
- pad_token_id=0
67
  )
68
 
69
- generated_sequence = output_sequences[0].tolist()
70
- text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
71
  text = text[: text.find('</s>')]
72
- text = text[text.find('#')+1:].strip() # Результат генерации содержит входную строку, поэтому отрезаем ее до символа "#".
73
  print(text)
74
  ```
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
 
46
  tokenizer = AutoTokenizer.from_pretrained("inkoziev/rugpt_interpreter")
47
+ tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
48
  model = AutoModelForCausalLM.from_pretrained("inkoziev/rugpt_interpreter")
49
  model.to(device)
50
 
 
52
  # В конце добавляем символ "#"
53
  input_text = """<s>- Как тебя зовут?
54
  - Джульетта Мао #"""
55
+ #input_text = """<s>- Что Предтечи забрали у Предшественников?
56
+ #- Они узурпировали у них Мантию — защиту всего живого в галактике #"""
57
+
58
+ encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
59
 
60
  output_sequences = model.generate(
61
  input_ids=encoded_prompt,
 
66
  repetition_penalty=1.2,
67
  do_sample=True,
68
  num_return_sequences=1,
69
+ pad_token_id=tokenizer.pad_token_id,
70
  )
71
 
72
+ text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
 
73
  text = text[: text.find('</s>')]
 
74
  print(text)
75
  ```