kix-intl's picture
Update app.py
61f7f1f verified
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import gradio as gr
# Load the fine-tuned DistilBERT sequence classifier and its tokenizer,
# identified by `model_name` (replace it with your own model id if needed).
model_name = "kix-intl/elon-musk-detector"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
# nn.Module.to() moves the weights in place and returns the same module.
model = DistilBertForSequenceClassification.from_pretrained(model_name).to(device)
def classify_tweet(tweet) -> str:
    """Classify whether *tweet* looks like it was written by Elon Musk.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        tweet: Text to classify. ``None`` or whitespace-only input is not
            sent to the model; a prompt message is returned instead.

    Returns:
        A human-readable verdict with the softmax probability of the
        predicted class, e.g. ``"Elon Musk (Confidence: 0.93)"``. Class
        index 1 is reported as "Elon Musk"; any other index as
        "Not Elon Musk".
    """
    # Guard: an empty textbox would otherwise be tokenized to only the special
    # tokens and still produce a (meaningless) confident prediction; None would
    # crash the tokenizer.
    if tweet is None or not tweet.strip():
        return "Please enter a tweet to classify."

    inputs = tokenizer(
        tweet,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,  # inputs longer than 128 tokens are truncated
    ).to(device)

    # Inference only: disable autograd tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single string is encoded as a batch of one, so row 0 holds its scores.
    probabilities = torch.softmax(outputs.logits, dim=1)
    prediction = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][prediction].item()

    label = "Elon Musk" if prediction == 1 else "Not Elon Musk"
    return f"{label} (Confidence: {confidence:.2f})"
# Build the Gradio UI: a two-line text box for the tweet in, a text box with
# the classifier's verdict out, plus a few clickable example tweets.
_example_tweets = [
    ["Tesla's new model is amazing!"],
    ["Just had a great coffee at my favorite local café."],
    ["Sending astronauts to Mars is the next big challenge for humanity."],
]

iface = gr.Interface(
    fn=classify_tweet,
    inputs=gr.Textbox(lines=2, placeholder="Enter a tweet here..."),
    outputs=gr.Textbox(),
    examples=_example_tweets,
    title="Elon Musk Tweet Classifier",
    description=(
        "This model classifies whether a given tweet is likely to be "
        "written by Elon Musk or not."
    ),
)

# Start the web app. Note: this call is unconditional (not behind an
# `if __name__ == "__main__":` guard), so it runs whenever this file executes.
iface.launch()