# Author: jjmakes
# Last commit: "Update app.py" (3b57e27)
import functools
import random

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Label names, one per model output. Scores are paired with these names
# positionally, so this order is assumed to match the fine-tuned models'
# logit order — keep the two in sync.
classifiers = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
def reset_scores():
    """Rebind the module-level ``scores_df`` to a fresh, empty DataFrame.

    The new frame has a ``Comment`` column followed by one column per label
    in ``classifiers``, and starts with no rows.
    """
    global scores_df
    header = ['Comment', *classifiers]
    scores_df = pd.DataFrame(columns=header)
# Fine-tuned checkpoint directory for each selectable base model. Any other
# value falls back to the RoBERTa checkpoint, as the original else-branch did.
_MODEL_DIRS = {
    "bert-base-cased": "./bert/_bert_model",
    "distilbert-base-cased": "./distilbert/_distilbert_model",
}
_DEFAULT_MODEL_DIR = "./roberta/_roberta_model"


# maxsize=1: a script run only ever scores with one selected model, so keeping
# a single model in memory avoids reloading it for every comment without
# holding several large checkpoints at once.
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_base):
    """Load and memoize the fine-tuned classifier and tokenizer for *model_base*.

    NOTE(review): Streamlit re-executes the script on each interaction, so this
    cache presumably only lives for one run; ``st.cache_resource`` would keep
    the model across reruns if the installed Streamlit version provides it.
    """
    model_dir = _MODEL_DIRS.get(model_base, _DEFAULT_MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_base)
    return model, tokenizer


def get_score(model_base, text):
    """Return per-label toxicity probabilities for *text*.

    Args:
        model_base: Hugging Face base-model name. It selects the fine-tuned
            checkpoint directory and is also used to load the tokenizer.
        text: The comment to classify; it is truncated to 512 tokens.

    Returns:
        ``torch.sigmoid`` of the model logits for a batch of one text. Row 0
        holds one independent probability per label; callers pair it
        positionally with ``classifiers``.
    """
    model, tokenizer = _load_model_and_tokenizer(model_base)
    inputs = tokenizer(
        text, max_length=512, truncation=True, padding=True, return_tensors='pt')
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.sigmoid(outputs.logits)
# Page header and user inputs: model choice, text to classify, submit button.
st.title("Toxic Comment Classifier")
st.write("John Makely")
st.write("The following models are fine-tuned on the jigsaw-toxic-comment-classification dataset")
st.write("Please be patient and give the queries and tables time to load (max 2 minutes)")
# Drop-down menu for model selection; the first option (roberta-base) is the default.
model_base = st.selectbox("Select a pretrained model",
                          ["roberta-base", "bert-base-cased", "distilbert-base-cased"])
text_input = st.text_input("Enter text for toxicity classification",
                           "I hope you die")
st.write("After hitting Submit, classification scores will be displayed for the provided text")
submit_btn = st.button("Submit")
if submit_btn and text_input:
    # Score the user's text and show one row of per-label percentages.
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    df = df.round(2)  # Round the values to 2 decimal places
    # Format the values as whole percentages. Series.map is applied column by
    # column because DataFrame.applymap is deprecated since pandas 2.1.
    df = df.apply(lambda col: col.map('{:.0%}'.format))
    st.table(df)
# Read the test dataset; only the comment text column is used below.
test_df = pd.read_csv(
    "./jigsaw-toxic-comment-classification-challenge/test.csv",
    usecols=['comment_text'])
# Select 3 random comments from the test dataset
sample_df = test_df.sample(n=3)
# Score each sampled comment, then build the frame in one go. Constructing it
# from the rows gives float label columns, so round() below takes effect.
# Appending rows with .loc onto an empty object-typed frame can leave object
# columns, which DataFrame.round() skips (pandas-version dependent).
rows = [
    [comment] + get_score(model_base, comment)[0].tolist()
    for comment in sample_df['comment_text']
]
scores_df = pd.DataFrame(rows, columns=['Comment'] + classifiers)
# Round the values to 2 decimal places
scores_df = scores_df.round(2)
st.subheader("Toxicity Scores for Random Comments")
st.write("The following table will grab random values from the jigsaw dataset and display their respective scores")
st.write("Please be patient as it may take some time for the scores to be passed through the model")
# Clicking the button re-runs the script, and that run has already drawn and
# scored a fresh sample above. Do not reset scores_df here: clearing it would
# show an empty table right below the "loaded" message.
if st.button("Refresh Random Tweets"):
    st.success("New tweets have been loaded!")
st.table(scores_df)