"""Gradio demo that generates Yelp-style reviews with a fine-tuned BART model.

The user picks values for four sliders (stars, useful, funny, cool); these are
formatted into a text prompt, fed to the seq2seq model with beam search, and
the top ``NUM_REVIEWS`` generated sequences are shown in separate textboxes.
"""

import os

import gradio as gr
import torch  # noqa: F401  (kept from the original script)
import transformers  # noqa: F401  (kept from the original script)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = 'eliolio/bart-finetuned-yelpreviews'
# Optional Hugging Face access token taken from the environment; when the
# variable is unset this is None and the model is fetched anonymously.
access_token = os.environ.get('private_token')

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=access_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)

# Number of reviews requested from generate() and shown in the UI. It must not
# exceed NUM_BEAMS: beam search can only return as many sequences as beams.
NUM_REVIEWS = 3
NUM_BEAMS = 5


def create_prompt(stars, useful, funny, cool):
    """Build the conditioning prompt the model expects from the slider values."""
    return f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"


def generate_reviews(stars, useful, funny, cool):
    """Generate ``NUM_REVIEWS`` candidate reviews for the given attributes.

    Args:
        stars, useful, funny, cool: Slider values inserted into the prompt.

    Returns:
        A tuple of ``NUM_REVIEWS`` decoded strings, one per output textbox,
        in the order returned by beam search.
    """
    text = create_prompt(stars, useful, funny, cool)
    inputs = tokenizer(text, return_tensors='pt')
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        num_beams=NUM_BEAMS,
        num_return_sequences=NUM_REVIEWS,
    )
    # batch_decode handles every returned sequence in one call; a tuple lets
    # Gradio map each element onto one output component.
    return tuple(tokenizer.batch_decode(out, skip_special_tokens=True))


css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient(90deg, #00d2ff 0%, #3a47d5 100%);}
"""

md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""


def _count_slider(label):
    """Return an integer 0-5 slider (default 0) with the given label."""
    # gr.Slider with value= replaces the removed gr.inputs.Slider(default=...).
    return gr.Slider(minimum=0, maximum=5, step=1, value=0, label=label)


with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')
    with gr.Row():
        stars = _count_slider("stars")
        useful = _count_slider("useful")
        funny = _count_slider("funny")
        cool = _count_slider("cool")
    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")

    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3],
    )

# Launch only when run as a script (e.g. `python app.py`), so that importing
# this module to reuse `demo` or `generate_reviews` does not start a server.
if __name__ == "__main__":
    demo.launch()