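"""Gradio Space that classifies the emotion of a tweet.

Serves two fine-tuned models from the Hugging Face Hub:
Emanuel/bertweet-emotion-base (loaded at startup) and
Emanuel/twitter-emotion-deberta-v3-base (loaded lazily on first request).
"""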
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

class TwitterEmotionClassifier:
    def __init__(self, model_name: str, model_type: str):
        # Runs on CPU by default; set to True when a CUDA device is available.
        self.is_gpu = False
        self.model_type = model_type
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        self.bertweet = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # transformers pipelines expect a CUDA device index, or -1 for CPU.
            device=0 if self.is_gpu else -1,
        )
        # The DeBERTa pipeline is created lazily on first request.
        self.deberta = None
        # Mapping from the models' raw label ids to emotion names.
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def get_model(self, model_type: str):
        # Return the requested pipeline, loading the DeBERTa model on first use.
        if self.model_type == "bertweet" and model_type == self.model_type:
            return self.bertweet
        elif model_type == "deberta":
            if self.deberta:
                return self.deberta
            model = AutoModelForSequenceClassification.from_pretrained(
                "Emanuel/twitter-emotion-deberta-v3-base"
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "Emanuel/twitter-emotion-deberta-v3-base"
            )
            self.deberta = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=0 if self.is_gpu else -1,
            )
            return self.deberta

    def predict(self, twitter: str, model_type: str):
        classifier = self.get_model(model_type)
        # return_all_scores=True yields one score per label for the single input.
        preds = classifier(twitter, return_all_scores=True)
        if preds:
            pred = preds[0]
            # Scores arrive in label-id order: sadness, joy, love, anger, fear, surprise.
            res = {
                "Sadness 😢": pred[0]["score"],
                "Joy 😂": pred[1]["score"],
                "Love 😍": pred[2]["score"],
                "Anger 😠": pred[3]["score"],
                "Fear 😱": pred[4]["score"],
                "Surprise 😮": pred[5]["score"],
            }
            return res
        return None

def main():
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.Textbox(
                placeholder="What's happening?", label="Tweet content", lines=5
            ),
            gr.Radio(["bertweet", "deberta"], label="Model"),
        ],
        outputs=gr.Label(num_top_classes=6, label="Emotions of this tweet"),
        examples=[
            [
                "Tesla Bot is truly amazing. It's the early steps of a revolution in the role that AI & robots play in human civilization. What the Tesla team has been able to accomplish in the last few months is just incredible. As someone who loves AI and robotics, I'm inspired beyond words.",
                "bertweet",
            ],
            [
                "I got food poisoning. It sucks 🥵 but it makes me appreciate: 1. the days when I'm not sick and 2. just how damn incredible the human body is at fighting off all the things that try to kill it. Biology is awesome. Life is awesome.",
                "bertweet",
            ],
            [
                "I'm adding human-created captions to many podcasts soon. (It's expensive 😅) These identify the speaker, are timed to the audio, and so make for good training data. When you and I do a podcast, we too will become immortalized as training data.",
                "bertweet",
            ],
            [
                "We live inside a simulation and are ourselves creating progressively more realistic and interesting simulations. Existence is a recursive simulation generator.",
                "bertweet",
            ],
            [
                "Here's my conversation with Will Sasso, one of the funniest people on the planet and someone who I've been a fan of for over 20 years. https://youtube.com/watch?v=xewD1apJNhw PS: His @Twitter account @WillSasso got hacked yesterday. @TwitterSupport please help him out!",
                "deberta",
            ],
        ],
title="Emotion classification ๐ค", | |
description="", | |
theme="huggingface", | |
) | |
interFace.launch() | |
if __name__ == "__main__": | |
main() |
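
# To run this script locally (dependencies inferred from the imports above):
#   pip install gradio torch transformers
# Both models are downloaded from the Hugging Face Hub on first use.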