zionia commited on
Commit
78ae48c
1 Parent(s): eaeabb1

update upload function

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -24,7 +24,7 @@ categories = {
24
  def prediction(news):
25
  classifier = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
26
  preds = classifier(news)
27
- preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
@@ -41,11 +41,11 @@ def file_prediction(file):
41
 
42
  results = []
43
  for news in news_list:
44
- if news.strip():
45
  pred = prediction(news)
46
- results.append([news, pred]) # Return each news and its prediction as a row
47
 
48
- return results # Gradio expects a list of lists or dicts for DataFrame
49
 
50
  gradio_ui = gr.Interface(
51
  fn=prediction,
@@ -69,3 +69,4 @@ gradio_file_ui = gr.Interface(
69
  gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
70
 
71
  gradio_combined_ui.launch()
 
 
24
  def prediction(news):
25
  classifier = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
26
  preds = classifier(news)
27
+ preds_dict = {categories.get(pred['label'], pred['label']): round(pred['score'], 4) for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
 
41
 
42
  results = []
43
  for news in news_list:
44
+ if news.strip():
45
  pred = prediction(news)
46
+ results.append([news, pred])
47
 
48
+ return results
49
 
50
  gradio_ui = gr.Interface(
51
  fn=prediction,
 
69
  gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
70
 
71
  gradio_combined_ui.launch()
72
+