Spaces:
Running
Running
File size: 3,082 Bytes
885b434 1dd5bbf 885b434 5a63293 cb34ab7 5a63293 885b434 4c9facd 885b434 b9bec37 1dd5bbf b9bec37 885b434 1dd5bbf 6b23fac 1dd5bbf 6b23fac 1dd5bbf 6b23fac 1dd5bbf 885b434 5a63293 4c9facd 4ecebd0 1dd5bbf 885b434 1dd5bbf 4ecebd0 6b23fac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import pandas as pd
# Hugging Face Hub id of the Setswana news classifier. The public model URL and
# the tokenizer/model loaded below are all derived from this single id.
MODEL_ID = "dsfsi/PuoBERTa-News"
MODEL_URL = f"https://huggingface.co/{MODEL_ID}"
# NOTE(review): not referenced anywhere in the visible script.
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"

# Load the tokenizer and fine-tuned classification head once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
# Model label -> Setswana display name for each PuoBERTa-News category.
categories = dict(
    arts_culture_entertainment_and_media="Botsweretshi, setso, boitapoloso le bobegakgang",
    crime_law_and_justice="Bosenyi, molao le bosiamisi",
    disaster_accident_and_emergency_incident="Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
    economy_business_and_finance="Ikonomi, tsa kgwebo le tsa ditšhelete",
    education="Thuto",
    environment="Tikologo",
    health="Boitekanelo",
    politics="Dipolotiki",
    religion_and_belief="Bodumedi le tumelo",
    society="Setšhaba",
)
# Shared text-classification pipeline; built on first use by _get_classifier().
_classifier = None


def _get_classifier():
    """Return the shared classification pipeline, creating it on first call."""
    global _classifier
    if _classifier is None:
        # return_all_scores=True makes the pipeline report a score for every
        # label rather than only the top one.
        _classifier = pipeline(
            "sentiment-analysis",
            tokenizer=tokenizer,
            model=model,
            return_all_scores=True,
        )
    return _classifier


def prediction(news):
    """Classify one Setswana news text into the PuoBERTa-News categories.

    The pipeline is created once and reused, instead of being rebuilt on
    every call (this function also runs once per row of an uploaded file).

    Args:
        news: The article text to classify.

    Returns:
        A dict mapping each category's Setswana display name (looked up in
        ``categories``; the raw model label when unmapped) to its score.
    """
    # truncation=True: inputs longer than the model's maximum sequence length
    # are cut to their leading tokens instead of making the forward pass fail.
    preds = _get_classifier()(news, truncation=True)
    # A single input string yields a one-element list holding the per-label
    # {"label", "score"} dicts.
    return {
        categories.get(pred["label"], pred["label"]): pred["score"]
        for pred in preds[0]
    }
def file_prediction(file):
    """Classify every news article in an uploaded text or CSV file.

    For a ``.csv`` file (case-insensitive extension), each non-empty cell in
    the first column is one article. Any other file is read as UTF-8 text and
    treated as a single article.

    Args:
        file: The value from the ``gr.File`` input. This is either a path
            string or an object whose ``name`` attribute holds the path of
            the uploaded copy. ``None`` when nothing was uploaded.

    Returns:
        A ``pandas.DataFrame`` with one row per article and the columns
        ``"News Text"`` and ``"Category Predictions"``, matching the headers
        of the output ``gr.Dataframe``. The predictions cell lists every
        category with its score, highest first.
    """
    columns = ["News Text", "Category Predictions"]
    if file is None:
        return pd.DataFrame(columns=columns)

    # Always read via the path. A plain path string has no ``read()``, and a
    # wrapper object's own handle is not guaranteed to hold the upload.
    path = getattr(file, "name", file)
    if str(path).lower().endswith(".csv"):
        df = pd.read_csv(path)
        # Skip empty cells (NaN): the classifier expects text, not floats.
        news_list = df.iloc[:, 0].dropna().astype(str).tolist()
    else:
        with open(path, encoding="utf-8") as handle:
            news_list = [handle.read()]

    rows = []
    for news in news_list:
        scores = prediction(news)
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        summary = "; ".join(f"{label}: {score:.4f}" for label, score in ranked)
        rows.append([news, summary])
    return pd.DataFrame(rows, columns=columns)
# Tab 1: classify a single pasted article and show the five most probable
# categories.
_text_description = (
    "Enter Setswana news article to see the category of the news.\n"
    f" For this classification, the {MODEL_URL} model was used."
)
_news_textbox = gr.Textbox(lines=10, label="Paste some Setswana news here")
_category_label = gr.Label(num_top_classes=5, label="News categories probabilities")

gradio_ui = gr.Interface(
    fn=prediction,
    title="Setswana News Classification",
    description=_text_description,
    inputs=_news_textbox,
    outputs=_category_label,
)
# Tab 2: batch-classify the articles contained in an uploaded text or CSV file.
_upload_input = gr.File(label="Upload text or CSV file")
_predictions_table = gr.Dataframe(
    headers=["News Text", "Category Predictions"],
    label="Predictions from file",
)

gradio_file_ui = gr.Interface(
    fn=file_prediction,
    title="Upload File for Setswana News Classification",
    description=(
        "Upload a text or CSV file with Setswana news articles. "
        "The first column in the CSV should contain the news text."
    ),
    inputs=_upload_input,
    outputs=_predictions_table,
)
# Light-theme overrides: white backgrounds with black text for the page, the
# Gradio container and the input/button/file/table widgets, plus light-grey
# buttons and table cells.
css = """
body {
background-color: white !important;
color: black !important;
}
.gradio-container {
background-color: white !important;
color: black !important;
}
.gr-input, .gr-button, .gr-textbox, .gr-file, .gr-dataframe {
background-color: white !important;
color: black !important;
border-color: #ccc !important;
}
.gr-button {
background-color: #f0f0f0 !important;
color: black !important;
border: 1px solid #ccc !important;
}
.gr-dataframe th, .gr-dataframe td {
background-color: #f9f9f9 !important;
color: black !important;
}
"""

# One app, one tab per input mode; tab titles line up with the interfaces.
_tab_interfaces = [gradio_ui, gradio_file_ui]
_tab_titles = ["Text Input", "File Upload"]
gradio_combined_ui = gr.TabbedInterface(_tab_interfaces, _tab_titles)

gradio_combined_ui.launch(css=css)
|