import logging
import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from resources import banner, error_html_response

logging.basicConfig(format='%(asctime)s: [%(levelname)s]: %(message)s', level=logging.INFO)

model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Structure tokens the checkpoint uses to delimit the fields of a recipe.
special_tokens = [
    '<INPUT_START>',
    '<NEXT_INPUT>',
    '<INPUT_END>',
    '<TITLE_START>',
    '<TITLE_END>',
    '<INGR_START>',
    '<NEXT_INGR>',
    '<INGR_END>',
    '<INSTR_START>',
    '<NEXT_INSTR>',
    '<INSTR_END>',
]
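
# The prompt built in make_recipe and the text the model is expected to
# generate follow roughly this layout (enforced by check_special_tokens_order
# below); <RECIPE_START> / <RECIPE_END> frame the whole sequence even though
# they are not part of the list above:
#
#   <RECIPE_START> <INPUT_START> ing_1 <NEXT_INPUT> ing_2 <INPUT_END>
#   <INGR_START> ... <NEXT_INGR> ... <INGR_END>
#   <INSTR_START> ... <NEXT_INSTR> ... <INSTR_END>
#   <TITLE_START> ... <TITLE_END> <RECIPE_END>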


def frame_html_response(html_response):
    # Wrap the generated page in a sandboxed iframe so the Gradio HTML output
    # renders it as a standalone document. srcdoc is delimited with single
    # quotes, so escape any single quote in the HTML to keep the attribute valid.
    safe_html = html_response.replace("'", "&#39;")
    return f"""<iframe style="width: 100%; height: 800px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{safe_html}'></iframe>"""


def check_special_tokens_order(pre_output):
    # True only when every section marker appears in the expected order:
    # inputs, ingredients, instructions, title.
    return (pre_output.find('<INPUT_START>') <
            pre_output.find('<NEXT_INPUT>') <=
            pre_output.rfind('<NEXT_INPUT>') <
            pre_output.find('<INPUT_END>') <
            pre_output.find('<INGR_START>') <
            pre_output.find('<NEXT_INGR>') <=
            pre_output.rfind('<NEXT_INGR>') <
            pre_output.find('<INGR_END>') <
            pre_output.find('<INSTR_START>') <
            pre_output.find('<NEXT_INSTR>') <=
            pre_output.rfind('<NEXT_INSTR>') <
            pre_output.find('<INSTR_END>') <
            pre_output.find('<TITLE_START>') <
            pre_output.find('<TITLE_END>'))
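
# Note: str.find returns -1 when a token is missing, so the check above is
# only meaningful after rerun_model_output has verified that every special
# token is present in the trimmed output.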


def make_html_response(title, ingredients, instructions):
    ingredients_html_list = '<ul><li>' + '</li><li>'.join(ingredients) + '</li></ul>'
    instructions_html_list = '<ol><li>' + '</li><li>'.join(instructions) + '</li></ol>'

    html_response = f'''
    <!DOCTYPE html>
    <html>
    <body>
    <h1>{title}</h1>

    <h2>Ingredientes</h2>
    {ingredients_html_list}

    <h2>Instrucciones</h2>
    {instructions_html_list}

    </body>
    </html>
    '''
    return html_response
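
# For example, make_html_response('Tarta', ['Harina', 'Huevos'], ['Mezclar.', 'Hornear.'])
# returns a small standalone page: the title as an <h1>, the ingredients as a
# <ul> and the instructions as a numbered <ol>.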


def rerun_model_output(pre_output):
    # Decide whether the sampled text needs to be regenerated: it must contain
    # <RECIPE_END>, include every special token before that marker, keep the
    # markers in the expected order, and be at least 75 whitespace-separated
    # tokens long.
    if pre_output is None:
        return True
    elif '<RECIPE_END>' not in pre_output:
        logging.info('<RECIPE_END> not in pre_output')
        return True
    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
    if not all(special_token in pre_output_trimmed for special_token in special_tokens):
        logging.info('Not all special tokens are in pre_output')
        return True
    elif not check_special_tokens_order(pre_output_trimmed):
        logging.info('Special tokens are unordered in pre_output')
        return True
    elif len(pre_output_trimmed.split()) < 75:
        logging.info('Length of the recipe is <75')
        return True
    else:
        return False


def check_wrong_ingredients(ingredients):
    # Ingredients generated as 'De <algo>' keep only the part after the 'De '
    # prefix, re-capitalized. Note: str.strip('De ') would strip characters,
    # not the prefix, so the string is sliced instead.
    new_ingredients = []
    for ingredient in ingredients:
        if ingredient.startswith('De '):
            new_ingredients.append(ingredient[len('De '):].capitalize())
        else:
            new_ingredients.append(ingredient)
    return new_ingredients
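
# e.g. check_wrong_ingredients(['De leche', 'Azúcar']) -> ['Leche', 'Azúcar'].
# On Python 3.9+ the slice above could also be written as
# ingredient.removeprefix('De '); the slice keeps the code compatible with
# older interpreters.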


def make_recipe(input_ingredients):
    logging.info(f'Received inputs: {input_ingredients}')
    # Normalize ' y ' (Spanish "and") to a comma so the list splits cleanly.
    input_ingredients = re.sub(' y ', ', ', input_ingredients)
    # Build the prompt up to <INGR_START>; the model is expected to complete
    # the ingredients, instructions and title sections.
    prompt = '<RECIPE_START> '
    prompt += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(input_ingredients.split(', ')) + ' <INPUT_END> '
    prompt += '<INGR_START> '
    tokenized_input = tokenizer(prompt, return_tensors='pt')

    pre_output = None
    i = 0
    # Sample until the output passes the sanity checks, giving up after three attempts.
    while rerun_model_output(pre_output):
        if i == 3:
            return frame_html_response(error_html_response)
        output = model.generate(**tokenized_input,
                                max_length=600,
                                do_sample=True,
                                top_p=0.92,
                                top_k=50,
                                num_return_sequences=3)
        pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
        i += 1

    # Keep everything before <RECIPE_END> and pull out each section.
    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
    output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1)
    output_ingredients = output_ingredients.split(' <NEXT_INGR> ')
    output_ingredients = list(set([output_ingredient.strip() for output_ingredient in output_ingredients]))
    output_ingredients = [output_ing.capitalize() for output_ing in output_ingredients]
    output_ingredients = check_wrong_ingredients(output_ingredients)
    output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
    output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1)
    output_instructions = output_instructions.split(' <NEXT_INSTR> ')

    html_response = make_html_response(output_title, output_ingredients, output_instructions)

    return frame_html_response(html_response)
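
# A quick local smoke test, kept commented out so it does not run when the app
# starts; assumes the checkpoint has already been downloaded:
#
# if __name__ == '__main__':
#     print(make_recipe('salmón, aceite de oliva, sal'))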


iface = gr.Interface(
    fn=make_recipe,
    inputs=[
        gr.inputs.Textbox(
            lines=1,
            placeholder='ingrediente_1, ingrediente_2, ..., ingrediente_n',
            label='Dime con qué ingredientes quieres que cocinemos hoy y te sugeriremos una receta tan pronto como nuestros fogones estén libres'),
    ],
    outputs=[
        gr.outputs.HTML(label='¡Esta es mi propuesta para ti! ¡Buen provecho!')
    ],
    examples=[
        ['salmón, zumo de naranja, aceite de oliva, sal, pimienta'],
        ['harina, azúcar, huevos, chocolate, levadura Royal']
    ],
    description=banner)

iface.launch(enable_queue=True)