File size: 1,181 Bytes
9ff73a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import transformers
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
import gradio as gr

# The tokenizer and the DistilBERT classifier are both read from the same
# local directory; local_files_only=True restricts loading to files on disk.
_MODEL_DIR = './model_save/'

tokenizer = DistilBertTokenizerFast.from_pretrained(
    _MODEL_DIR,
    local_files_only=True,
)
model = DistilBertForSequenceClassification.from_pretrained(
    _MODEL_DIR,
    local_files_only=True,
)

# Put the model in evaluation mode before serving predictions.
model.eval()

# Define a predict function
def predict(text):
    """Classify a movie review as 'Positive' or 'Negative'.

    Args:
        text: The review text to classify.

    Returns:
        'Positive' when the highest-scoring class index is non-zero,
        otherwise 'Negative'.
    """
    # Truncate to the model's position-embedding limit so that very long
    # reviews do not overflow the embedding table and raise at inference.
    encoding = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Gradients are never needed here; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
        )
    # argmax yields a class index (not a probability), so compare it as an
    # integer. NOTE(review): assumes index 1 is the positive class -- confirm
    # against the label mapping used in training.
    pred_label = int(torch.argmax(outputs['logits'], dim=1)[0])
    return 'Positive' if pred_label > 0 else 'Negative'

# Heading and blurb handed to the Gradio interface.
title = "Write a movie review"
description = (
    "Enter a review for a movie you've seen.  "
    "This tool will try to guess whether your review is positive or negative."
)

# Wire predict() to a free-text input and a label output, then start the app.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title=title,
    description=description,
)
demo.launch()