# Source captured from a Hugging Face Spaces page; the Space was reporting
# "Runtime error" at capture time.
import hashlib
import hmac
import json
import os
import re
from datetime import datetime

import gradio as gr
import pandas as pd

from Prediction import *
# Root of the persistent storage; override with BTV_PERSISTENT_PATH (e.g. when
# running somewhere /data does not exist).
persistent_path = os.environ.get("BTV_PERSISTENT_PATH", "/data")
# NOTE(review): HF_HOME is normally read when Hugging Face libraries are first
# imported; since `Prediction` is imported above this line, this assignment may
# not change where models are cached -- confirm, and set it before the imports
# if so.
os.environ['HF_HOME'] = os.path.join(persistent_path, ".huggingface")
# JSON-lines log of registration-form submissions (appended by check_save).
user_input_path = os.path.join(persistent_path, 'user.jsonl')
# MD5 hex digest of the token that unlocks the log download (logfile_query).
# Prefer supplying it via BTV_LOG_SECRET_MD5 instead of keeping it in source;
# the literal is kept as the default for backward compatibility.
secret = os.environ.get("BTV_LOG_SECRET_MD5", "2fc9ff032e027e8f23bb9fb693234899")
def get_md5(s):
    """Return the lowercase hex MD5 digest of the UTF-8 encoding of *s*."""
    digest = hashlib.md5(s.encode('utf-8'))
    return digest.hexdigest()
# Example sentences shown in the UI (single-sentence examples and the CSV
# format preview): one per line in assets/examples.txt, with built-in defaults.
_DEFAULT_EXAMPLES = [
    "Games of the imagination teach us actions have consequences in a realm that can be reset.",
    "But New Jersey farmers are retiring and all over the state, development continues to push out dwindling farmland.",
    "He also is the Head Designer of The Design Trust so-to-speak, besides his regular job ..."
]
examples = []
if os.path.exists("assets/examples.txt"):
    with open("assets/examples.txt", "r", encoding="utf8") as file:
        # Skip blank lines (e.g. a trailing newline at EOF) so they do not
        # become empty example entries.
        examples = [line.strip() for line in file if line.strip()]
if not examples:
    # Missing or effectively empty file: use the defaults. An empty list would
    # also make `examples_per_page=len(examples)` zero in the UI.
    examples = list(_DEFAULT_EXAMPLES)
# Hub repository holding both the fine-tuned classifier and its tokenizer.
_MODEL_REPO = "Oliver12315/Brand_Tone_of_Voice"

# Inference runs on the CPU; the model is moved there once at start-up.
device = torch.device('cpu')
tokenizer = BertTokenizer.from_pretrained(_MODEL_REPO)
model = BertForSequenceClassification.from_pretrained(_MODEL_REPO).to(device)
def single_sentence(sentence):
    """Classify one sentence and rank the labels by score.

    Returns a list of ``(label, score)`` pairs, highest score first, suitable
    for the "Label"/"Probability" predictions table.
    """
    scores = predict_single(sentence, tokenizer, model, device)
    ranked = list(zip(LABEL_COLUMNS, scores))
    ranked.sort(key=lambda pair: pair[-1], reverse=True)
    return ranked
def csv_process(csv_file, attr="content"):
    """Run batch prediction over an uploaded CSV and return the output path.

    Args:
        csv_file: Value of the ``gr.File`` input -- either an object exposing
            ``.name`` (temp-file wrapper) or a plain filepath string, depending
            on the Gradio version; ``None`` when nothing was uploaded.
        attr: Name of the column holding the text to classify.

    Returns:
        A one-element list containing the path of the predictions CSV written
        under ``output/``.

    Raises:
        gr.Error: If no file was uploaded or the CSV lacks the ``attr`` column.
    """
    if csv_file is None:
        raise gr.Error("Please upload a CSV file first.")
    # Accept both a file wrapper (``.name``) and a bare filepath string.
    csv_path = getattr(csv_file, "name", csv_file)
    formatted_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    data = pd.read_csv(csv_path)
    if attr not in data.columns:
        raise gr.Error(f"The CSV must contain a '{attr}' column.")
    data = data.reset_index()
    os.makedirs('output', exist_ok=True)
    predictions = predict_csv(data, attr, tokenizer, model, device)
    output_path = f"output/prediction_Brand_Tone_of_Voice_{formatted_time}.csv"
    predictions.to_csv(output_path)
    return [output_path]
def logfile_query(auth):
    """Return the registration log for download if *auth* is the admin token.

    The MD5 digest of *auth* is compared against ``secret``.

    Returns:
        ``[user_input_path]`` when the token matches and the log file exists,
        otherwise ``None``.
    """
    if not auth:
        return None
    # Constant-time comparison so response timing does not reveal how much of
    # the digest matched.
    token_ok = hmac.compare_digest(get_md5(auth), secret)
    if token_ok and os.path.exists(user_input_path):
        return [user_input_path]
    return None
# Loose e-mail shape check (not a full RFC 5322 validator); used with
# fullmatch so a trailing newline cannot slip past an end anchor.
_EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')


def _clean(value):
    """Return *value* without surrounding whitespace; ``None`` becomes ''."""
    return (value or "").strip()


def _validate_registration(fname, lname, cnum, email, oname, position):
    """Validate already-cleaned registration fields.

    Returns:
        ``(errors, valid_vars)``: ``errors`` lists user-facing messages in form
        order (empty when every field passes); ``valid_vars`` holds the
        accepted fields under the keys written to the log.
    """
    errors = []
    valid_vars = {}

    if not fname or not lname:
        errors.append("Name cannot be empty")
    elif fname.isdigit() or lname.isdigit():
        errors.append("Name cannot be purely numerical")
    else:
        valid_vars["fname"] = fname
        valid_vars["lname"] = lname

    # Optional field: logged as '' when left blank.
    valid_vars["cnum"] = ''
    if cnum:
        if not cnum.isdigit():
            errors.append("The phone number must be a pure number")
        else:
            valid_vars["cnum"] = cnum

    if not email:
        errors.append("Email cannot be empty")
    elif not _EMAIL_RE.fullmatch(email):
        errors.append("Incorrect email format")
    else:
        valid_vars["email"] = email

    if not oname:
        errors.append("Organization name cannot be empty")
    elif oname.isdigit():
        errors.append("Organization cannot be purely numerical")
    else:
        valid_vars["oname"] = oname

    # Optional field: logged as '' when left blank.
    valid_vars["position"] = ''
    if position:
        if position.isdigit():
            errors.append("Position in your company cannot be purely numerical")
        else:
            valid_vars["position"] = position

    return errors, valid_vars


def _append_record(record):
    """Append *record* as one JSON line to ``user_input_path``."""
    log_dir = os.path.dirname(user_input_path)
    if log_dir:
        # Create the storage directory on first use rather than failing with
        # FileNotFoundError when it does not exist yet.
        os.makedirs(log_dir, exist_ok=True)
    with open(user_input_path, 'a', encoding="utf8") as file:
        file.write(json.dumps(record) + "\n")


def _load_records():
    """Read every logged record, keyed by its ``time`` stamp.

    Blank lines are skipped. Records that share a second-resolution timestamp
    overwrite each other (the later line wins).
    """
    records = {}
    with open(user_input_path, 'r', encoding="utf8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            records[record['time']] = record
    return records


def check_save(fname, lname, cnum, email, oname, position):
    """Validate the registration form and append it to the user log.

    Inputs are the raw textbox values. Surrounding whitespace is ignored and
    ``None`` is treated as an empty field. ``cnum`` and ``position`` are
    optional.

    Returns:
        A list of error messages if any field is invalid (nothing is written);
        otherwise a dict of all logged records keyed by their ``time`` stamp,
        including the one just appended.
    """
    cleaned = [_clean(v) for v in (fname, lname, cnum, email, oname, position)]
    errors, valid_vars = _validate_registration(*cleaned)
    if errors:
        return errors
    valid_vars['time'] = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    _append_record(valid_vars)
    return _load_records()
my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")

# Layout: the registration form (`regis`) is visible first; the prediction
# tabs (`mainrow`) stay hidden until `submit` accepts the form.
with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
    # The badge row's `style` attribute below now has its closing quote; it was
    # unterminated, which swallowed the following markup.
    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/xxx" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
</a>
<div>
<h1 >Place the title of the paper here</h1>
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
<a href='https://huggingface.co/spaces/Oliver12315/Brand_Tone_of_Voice_demo'><img src='https://img.shields.io/badge/Project_Page-Oliver12315/Brand_Tone_of_Voice_demo' alt='Project Page'></a>
<a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
</div>
</div>
</div>
""")

    with gr.Column(visible=True) as regis:
        # NOTE(review): this text refers to "ticking this box", but the form
        # has no consent checkbox, and its first sentence reads like a drafting
        # note -- confirm the intended wording.
        gr.Markdown("# Welcome to BTV! Please fill out the form below to continue.\nI’m assuming that you mention somewhere that this project/research is conducted by the University of Manchester/AMBS. By ticking this box, I consent to be approached by the research team of the University of Manchester.")
        with gr.Column(variant='panel'):
            fname_tb = gr.Textbox(label="First Name: ", type='text')
            lname_tb = gr.Textbox(label="Last Name: ", type='text')
            email_tb = gr.Textbox(label="Email: ", type='email')
            cnum_tb = gr.Textbox(label="Contact: (Optional)", type='text')
            oname_tb = gr.Textbox(label="Organization name: ", type='text')
            position_tb = gr.Textbox(label="Positions in your company: (Optional)", type='text')
            error_box = gr.HTML(value="", visible=False)
            submit_btn = gr.Button("Click here to start once you have filled in all the required fields!")

    with gr.Row(visible=False) as mainrow:
        with gr.Tab("Single Sentence"):
            with gr.Row():
                tbox_input = gr.Textbox(label="Input",
                                        info="Please input a sentence here:")
                gr.Markdown("""
# Detailed information about our model:
...
""")
            tab_output = gr.DataFrame(label='Predictions:',
                                      headers=["Label", "Probability"],
                                      datatype=["str", "number"],
                                      interactive=False)
            with gr.Row():
                button_ss = gr.Button("Submit", variant="primary")
                button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
                gr.ClearButton([tbox_input, tab_output])
            gr.Examples(
                examples=examples,
                inputs=tbox_input,
                examples_per_page=len(examples)
            )

        with gr.Tab("Csv File"):
            with gr.Row():
                csv_input = gr.File(label="CSV File:",
                                    file_types=['.csv'],
                                    file_count="single"
                                    )
                csv_output = gr.File(label="Predictions:")
            with gr.Row():
                button_cf = gr.Button("Submit", variant="primary")
                button_cf.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
                gr.ClearButton([csv_input, csv_output])
            gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
            gr.DataFrame(label='Csv input format:',
                         value=[[i, text] for i, text in enumerate(examples)],
                         headers=["index", "content"],
                         datatype=["number", "str"],
                         interactive=False
                         )

        with gr.Tab("Readme"):
            gr.Markdown(
                """
# Paper Name
# Authors
+ First author
+ Corresponding author
# Detailed Information
...
"""
            )

        with gr.Tab("Log File"):
            with gr.Row():
                auth_token = gr.Textbox(label="Authentication Tokens: ", info="Enter the key to download persistent stored log information.")
                log_output = gr.File(label="Log file: ")
            with gr.Row():
                button_lf = gr.Button("Validate", variant="primary")
                button_lf.click(fn=logfile_query, inputs=[auth_token], outputs=[log_output])
                gr.ClearButton([auth_token, log_output])

    def submit(*user_input):
        """Validate and store the registration form, then switch views.

        ``user_input`` holds the six textbox values in the order wired to
        ``submit_btn.click`` below, which matches ``check_save``'s parameters.

        Returns a dict of component updates. On validation failure, only the
        error box is shown, with the joined messages. On success, the form is
        hidden, the prediction tabs are revealed and the error box is hidden.
        """
        res = check_save(*user_input)
        if isinstance(res, list):
            # check_save returns a list of error messages when validation fails
            # and a dict of stored records otherwise.
            message = "; ".join(res)
            return {
                error_box: gr.update(
                    value=f"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<p style="color:red;">{message}</p>
</div>
</div>
""",
                    visible=True),
            }
        # gr.update is component-agnostic, so the same form works for the
        # Column `regis` and the Row `mainrow`.
        return {
            mainrow: gr.update(visible=True),
            regis: gr.update(visible=False),
            error_box: gr.update(visible=False),
        }

    submit_btn.click(
        submit,
        [fname_tb, lname_tb, cnum_tb, email_tb, oname_tb, position_tb],
        [mainrow, regis, error_box],
    )

if __name__ == "__main__":
    demo.launch()