Spaces:
Runtime error
Runtime error
HakanKilic01
commited on
Commit
•
10c8213
1
Parent(s):
add4b49
07/01/23-15:27
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4 |
|
5 |
|
6 |
|
7 |
-
def
|
8 |
pattern = r"'''py\n(.*?)'''"
|
9 |
match = re.search(pattern, input_text, re.DOTALL)
|
10 |
|
@@ -12,18 +12,27 @@ def extract_code(input_text):
|
|
12 |
return match.group(1)
|
13 |
else:
|
14 |
return None # Return None if no match is found
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def generate_code(input_text,modelName):
|
17 |
if(modelName == "codegen-350M"):
|
18 |
input_ids = codeGenTokenizer(input_text, return_tensors="pt").input_ids
|
19 |
generated_ids = codeGenModel.generate(input_ids, max_length=128)
|
20 |
result = codeGenTokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
21 |
-
return
|
22 |
elif(modelName == "mistral-7b"):
|
23 |
input_ids = mistralTokenizer(generate_prompt_mistral(input_text), return_tensors="pt").input_ids
|
24 |
generated_ids = mistralModel.generate(input_ids, max_length=128)
|
25 |
result = mistralTokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
26 |
-
return result
|
27 |
else:
|
28 |
return None
|
29 |
|
|
|
4 |
|
5 |
|
6 |
|
7 |
+
def extract_code_codegen(input_text):
|
8 |
pattern = r"'''py\n(.*?)'''"
|
9 |
match = re.search(pattern, input_text, re.DOTALL)
|
10 |
|
|
|
12 |
return match.group(1)
|
13 |
else:
|
14 |
return None # Return None if no match is found
|
15 |
+
|
16 |
+
def extract_code_mistral(input_text):
|
17 |
+
pattern = r'\[CODE\](.*?)\[/CODE\]'
|
18 |
+
match = re.search(pattern, input_text, re.DOTALL)
|
19 |
+
|
20 |
+
if match:
|
21 |
+
return match.group(1)
|
22 |
+
else:
|
23 |
+
return None # Return None if no match is found
|
24 |
|
25 |
def generate_code(input_text,modelName):
|
26 |
if(modelName == "codegen-350M"):
|
27 |
input_ids = codeGenTokenizer(input_text, return_tensors="pt").input_ids
|
28 |
generated_ids = codeGenModel.generate(input_ids, max_length=128)
|
29 |
result = codeGenTokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
30 |
+
return extract_code_codegen(result)
|
31 |
elif(modelName == "mistral-7b"):
|
32 |
input_ids = mistralTokenizer(generate_prompt_mistral(input_text), return_tensors="pt").input_ids
|
33 |
generated_ids = mistralModel.generate(input_ids, max_length=128)
|
34 |
result = mistralTokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
35 |
+
return extract_code_mistral(result)
|
36 |
else:
|
37 |
return None
|
38 |
|