# Spaces:
# Runtime error
# Runtime error
import functools
import random

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Toxicity label names, used as the column headers of every score table.
# NOTE(review): the order is assumed to match the order of the models'
# output logits -- confirm against the training code.
classifiers = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
def reset_scores():
    """Rebind the module-level ``scores_df`` to a fresh, empty score table.

    The new frame has no rows; its columns are ``'Comment'`` followed by
    every label in ``classifiers``.
    """
    global scores_df
    column_names = ['Comment', *classifiers]
    scores_df = pd.DataFrame(columns=column_names)
# Local fine-tuned checkpoint directory for each selectable base model.
# Any base model not listed here falls back to the RoBERTa checkpoint.
_MODEL_DIRS = {
    "bert-base-cased": "./bert/_bert_model",
    "distilbert-base-cased": "./distilbert/_distilbert_model",
}
_DEFAULT_MODEL_DIR = "./roberta/_roberta_model"

# Streamlit re-executes the whole script on every widget interaction, so
# prefer its resource cache, which keeps loaded models alive across reruns.
# Streamlit releases without ``cache_resource`` fall back to an in-process
# memo, which still avoids reloading the model for every comment scored
# within a single run.
_cache_models = getattr(st, "cache_resource", None)
if _cache_models is None:
    _cache_models = functools.lru_cache(maxsize=None)


@_cache_models
def _load_model_and_tokenizer(model_base):
    """Load the fine-tuned model and matching tokenizer for *model_base*."""
    model_dir = _MODEL_DIRS.get(model_base, _DEFAULT_MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # The tokenizer is loaded by base-model name, not from the local
    # checkpoint directory.
    tokenizer = AutoTokenizer.from_pretrained(model_base)
    return model, tokenizer


def get_score(model_base, text):
    """Score *text* with the fine-tuned classifier for *model_base*.

    Args:
        model_base: Base model name. ``"bert-base-cased"`` and
            ``"distilbert-base-cased"`` select their own checkpoints; any
            other value selects the RoBERTa checkpoint.
        text: The comment to classify. Input longer than 512 tokens is
            truncated.

    Returns:
        A ``torch.Tensor`` holding the element-wise sigmoid of the model's
        logits. Callers index ``[0]`` to get the scores for the single
        input text.
    """
    model, tokenizer = _load_model_and_tokenizer(model_base)
    # Calling the tokenizer directly replaces the legacy ``encode_plus``.
    inputs = tokenizer(
        text, max_length=512, truncation=True, padding=True,
        return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.sigmoid(outputs.logits)
# ---------------------------------------------------------------------------
# Part 1: classify a single user-supplied comment.
# ---------------------------------------------------------------------------
st.title("Toxic Comment Classifier")
st.write("John Makely")
st.write("The following model's are fine tuned on the jigsaw-toxic-comment-classification dataset")
st.write("Please be patient and give the queries and tables time to load (max 2 minutes)")

# Drop-down menu for model selection; the first entry (roberta) is the default.
model_base = st.selectbox("Select a pretrained model",
                          ["roberta-base", "bert-base-cased", "distilbert-base-cased"])

text_input = st.text_input("Enter text for toxicity classification",
                           "I hope you die")
st.write("After hitting Submit, classification scores will be displayed for the provided text")
submit_btn = st.button("Submit")

if submit_btn and text_input:
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    # Round to 2 decimal places, then render each value as a whole
    # percentage. Mapping column by column avoids DataFrame.applymap, which
    # newer pandas releases deprecate.
    df = df.round(2).apply(lambda column: column.map('{:.0%}'.format))
    st.table(df)

# ---------------------------------------------------------------------------
# Part 2: score a few random comments from the Jigsaw test set.
# ---------------------------------------------------------------------------
# Emit the explanatory text before the slow scoring below so the user sees
# it while waiting; the on-page order of elements is unchanged.
st.subheader("Toxicity Scores for Random Comments")
st.write("The following table will grab random values from the jigsaw dataset and display their respective scores")
st.write("Please be patient as it may take some time for the scores to be passed through the model")

# Read the test dataset
test_df = pd.read_csv(
    "./jigsaw-toxic-comment-classification-challenge/test.csv")

# Select 3 random comments from the test dataset (fewer if the file has
# fewer rows, since sampling more rows than exist raises ValueError).
sample_df = test_df.sample(n=min(3, len(test_df)))

# Start from an empty score table, then score every sampled comment.
reset_scores()
score_rows = [
    [comment] + get_score(model_base, comment)[0].tolist()
    for comment in sample_df['comment_text']
]
# Build the table in one go so pandas infers numeric dtypes for the score
# columns, which round() needs in order to act on them.
scores_df = pd.DataFrame(score_rows, columns=scores_df.columns).round(2)

# Any button click reruns this script from the top, so by the time this
# branch executes a fresh sample has already been drawn and scored above.
# Clearing scores_df here would leave the table below empty, so the branch
# only confirms the refresh.
if st.button("Refresh Random Tweets"):
    st.success("New tweets have been loaded!")

st.table(scores_df)