Update koalpaca.py

#8
by 4n3mone - opened
Files changed (1) hide show
  1. koalpaca.py +2 -1
koalpaca.py CHANGED
@@ -7,6 +7,7 @@ class KoAlpaca(Model):
7
  def __init__(self):
8
  peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
9
  config = PeftConfig.from_pretrained(peft_model_id)
 
10
  self.bnb_config = BitsAndBytesConfig(
11
  load_in_4bit=True,
12
  bnb_4bit_use_double_quant=True,
@@ -28,7 +29,7 @@ class KoAlpaca(Model):
28
  inputs,
29
  return_tensors='pt',
30
  return_token_type_ids=False
31
- ).to('cpu'),
32
  generation_config=self.gen_config
33
  )
34
  outputs = self.tokenizer.decode(output_ids[0]).split("### λ‹΅λ³€: ")[-1]
 
7
  def __init__(self):
8
  peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
9
  config = PeftConfig.from_pretrained(peft_model_id)
10
+ accelerator = Accelerator()
11
  self.bnb_config = BitsAndBytesConfig(
12
  load_in_4bit=True,
13
  bnb_4bit_use_double_quant=True,
 
29
  inputs,
30
  return_tensors='pt',
31
  return_token_type_ids=False
32
+ ).to(accelerator.device),
33
  generation_config=self.gen_config
34
  )
35
  outputs = self.tokenizer.decode(output_ids[0]).split("### λ‹΅λ³€: ")[-1]