HakanKilic01 commited on
Commit
10c8213
1 Parent(s): add4b49

07/01/23-15:27

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
 
6
 
7
- def extract_code(input_text):
8
  pattern = r"'''py\n(.*?)'''"
9
  match = re.search(pattern, input_text, re.DOTALL)
10
 
@@ -12,18 +12,27 @@ def extract_code(input_text):
12
  return match.group(1)
13
  else:
14
  return None # Return None if no match is found
 
 
 
 
 
 
 
 
 
15
 
16
  def generate_code(input_text,modelName):
17
  if(modelName == "codegen-350M"):
18
  input_ids = codeGenTokenizer(input_text, return_tensors="pt").input_ids
19
  generated_ids = codeGenModel.generate(input_ids, max_length=128)
20
  result = codeGenTokenizer.decode(generated_ids[0], skip_special_tokens=True)
21
- return extract_code(result)
22
  elif(modelName == "mistral-7b"):
23
  input_ids = mistralTokenizer(generate_prompt_mistral(input_text), return_tensors="pt").input_ids
24
  generated_ids = mistralModel.generate(input_ids, max_length=128)
25
  result = mistralTokenizer.decode(generated_ids[0], skip_special_tokens=True)
26
- return result
27
  else:
28
  return None
29
 
 
4
 
5
 
6
 
7
+ def extract_code_codegen(input_text):
8
  pattern = r"'''py\n(.*?)'''"
9
  match = re.search(pattern, input_text, re.DOTALL)
10
 
 
12
  return match.group(1)
13
  else:
14
  return None # Return None if no match is found
15
+
16
+ def extract_code_mistral(input_text):
17
+ pattern = r'\[CODE\](.*?)\[/CODE\]'
18
+ match = re.search(pattern, input_text, re.DOTALL)
19
+
20
+ if match:
21
+ return match.group(1)
22
+ else:
23
+ return None # Return None if no match is found
24
 
25
  def generate_code(input_text,modelName):
26
  if(modelName == "codegen-350M"):
27
  input_ids = codeGenTokenizer(input_text, return_tensors="pt").input_ids
28
  generated_ids = codeGenModel.generate(input_ids, max_length=128)
29
  result = codeGenTokenizer.decode(generated_ids[0], skip_special_tokens=True)
30
+ return extract_code_codegen(result)
31
  elif(modelName == "mistral-7b"):
32
  input_ids = mistralTokenizer(generate_prompt_mistral(input_text), return_tensors="pt").input_ids
33
  generated_ids = mistralModel.generate(input_ids, max_length=128)
34
  result = mistralTokenizer.decode(generated_ids[0], skip_special_tokens=True)
35
+ return extract_code_mistral(result)
36
  else:
37
  return None
38