File size: 3,082 Bytes
885b434
 
1dd5bbf
885b434
5a63293
cb34ab7
 
5a63293
 
885b434
4c9facd
 
 
 
 
 
 
 
 
 
 
 
 
885b434
b9bec37
 
1dd5bbf
b9bec37
885b434
1dd5bbf
6b23fac
1dd5bbf
6b23fac
 
1dd5bbf
6b23fac
1dd5bbf
 
 
 
 
 
 
885b434
 
5a63293
4c9facd
4ecebd0
 
1dd5bbf
 
 
 
 
 
 
 
885b434
 
1dd5bbf
4ecebd0
6b23fac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import pandas as pd

# Hugging Face Hub identifier of the classifier. Kept in one place so the
# public URL and both `from_pretrained` calls cannot drift apart.
MODEL_ID = "dsfsi/PuoBERTa-News"
MODEL_URL = f"https://huggingface.co/{MODEL_ID}"
# NOTE(review): WEBSITE_URL is not referenced anywhere else in this file —
# confirm whether it is still needed.
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"

# Loaded once at import time and shared by every prediction request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

# Model label -> Setswana display name, in the order the labels are listed.
# Labels missing from this table are shown unchanged by the caller.
_CATEGORY_PAIRS = (
    ("arts_culture_entertainment_and_media", "Botsweretshi, setso, boitapoloso le bobegakgang"),
    ("crime_law_and_justice", "Bosenyi, molao le bosiamisi"),
    ("disaster_accident_and_emergency_incident", "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso"),
    ("economy_business_and_finance", "Ikonomi, tsa kgwebo le tsa ditšhelete"),
    ("education", "Thuto"),
    ("environment", "Tikologo"),
    ("health", "Boitekanelo"),
    ("politics", "Dipolotiki"),
    ("religion_and_belief", "Bodumedi le tumelo"),
    ("society", "Setšhaba"),
)
categories = dict(_CATEGORY_PAIRS)

# Lazily-built pipeline shared across calls; constructing it per request
# repeats setup work for no benefit.
_classifier = None


def _get_classifier():
    """Return the shared text-classification pipeline, building it on first use."""
    global _classifier
    if _classifier is None:
        _classifier = pipeline(
            "sentiment-analysis",
            tokenizer=tokenizer,
            model=model,
            return_all_scores=True,
        )
    return _classifier


def prediction(news):
    """Classify a news text and return a score for every category.

    Args:
        news: The article text to classify. ``None`` or a blank string
            yields an empty result instead of being sent to the model.

    Returns:
        A dict mapping the Setswana category name (or the raw model label
        when it has no entry in ``categories``) to the model's score.
    """
    if news is None or (isinstance(news, str) and not news.strip()):
        return {}

    # Truncate over-long articles to the tokenizer's maximum length;
    # presumably the model has a fixed input limit — TODO confirm.
    preds = _get_classifier()(news, truncation=True)

    # The pipeline normally wraps the per-label scores of a single text in an
    # outer list; tolerate a flat list of score dicts as well.
    scores = preds[0] if isinstance(preds[0], list) else preds
    return {
        categories.get(item["label"], item["label"]): item["score"]
        for item in scores
    }

def _format_scores(scores):
    """Render a {category: score} mapping as one table cell, best score first."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return "; ".join(f"{label}: {score:.4f}" for label, score in ranked)


def file_prediction(file):
    """Classify every news text found in an uploaded text or CSV file.

    Args:
        file: Either a filesystem path (``str``) or an object exposing the
            path as ``.name``. A ``.csv`` file contributes one text per row
            of its first column; any other file is read as a single UTF-8
            text.

    Returns:
        A list of ``[news_text, predictions]`` rows, matching the two-column
        table that displays the result. ``predictions`` is a string of
        ``"category: score"`` pairs sorted by descending score. Blank or
        missing texts are skipped, so an empty upload yields ``[]``.
    """
    if file is None:
        return []

    # Accept both a plain path string and a file-like wrapper with `.name`.
    path = file if isinstance(file, str) else file.name

    if path.lower().endswith(".csv"):
        try:
            first_column = pd.read_csv(path).iloc[:, 0]
        except pd.errors.EmptyDataError:
            return []
        # Drop missing cells so NaN is never sent to the classifier.
        news_list = first_column.dropna().astype(str).tolist()
    else:
        with open(path, encoding="utf-8") as handle:
            news_list = [handle.read()]

    rows = []
    for news in news_list:
        if not news.strip():
            continue
        rows.append([news, _format_scores(prediction(news))])
    return rows

# --- Tab 1: classify a single pasted article -------------------------------
_text_description = f"Enter Setswana news article to see the category of the news.\n For this classification, the {MODEL_URL} model was used."
_text_box = gr.Textbox(lines=10, label="Paste some Setswana news here")
_label_view = gr.Label(num_top_classes=5, label="News categories probabilities")

gradio_ui = gr.Interface(
    fn=prediction,
    inputs=_text_box,
    outputs=_label_view,
    title="Setswana News Classification",
    description=_text_description,
)

# --- Tab 2: classify articles from an uploaded text/CSV file ---------------
_file_description = "Upload a text or CSV file with Setswana news articles. The first column in the CSV should contain the news text."
_file_picker = gr.File(label="Upload text or CSV file")
_result_table = gr.Dataframe(
    headers=["News Text", "Category Predictions"],
    label="Predictions from file",
)

gradio_file_ui = gr.Interface(
    fn=file_prediction,
    inputs=_file_picker,
    outputs=_result_table,
    title="Upload File for Setswana News Classification",
    description=_file_description,
)

# Both interfaces are exposed as tabs of a single app.
gradio_combined_ui = gr.TabbedInterface(
    interface_list=[gradio_ui, gradio_file_ui],
    tab_names=["Text Input", "File Upload"],
)

# Custom stylesheet: forces a white background with black text on the page
# body, the Gradio container, and the `.gr-*` input/button/textbox/file/
# dataframe selectors, with light-grey borders and table cells.
css = """
body {
    background-color: white !important;
    color: black !important;
}

.gradio-container {
    background-color: white !important;
    color: black !important;
}

.gr-input, .gr-button, .gr-textbox, .gr-file, .gr-dataframe {
    background-color: white !important;
    color: black !important;
    border-color: #ccc !important;
}

.gr-button {
    background-color: #f0f0f0 !important;
    color: black !important;
    border: 1px solid #ccc !important;
}

.gr-dataframe th, .gr-dataframe td {
    background-color: #f9f9f9 !important;
    color: black !important;
}
"""

# Start the tabbed app, handing it the stylesheet defined above.
# NOTE(review): whether `launch()` accepts a `css` argument depends on the
# installed Gradio version (some releases take it on the Blocks/Interface
# constructor instead) — confirm against the pinned version.
gradio_combined_ui.launch(css=css)