dnnsdunca commited on
Commit
439c613
1 Parent(s): 690827e

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +24 -0
predict.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
+ import torch
3
+ import json
4
+
5
+ # Load configuration
6
+ with open('../config/config.json') as f:
7
+ config = json.load(f)
8
+
9
+ # Load model and tokenizer
10
+ model = AutoModelForSequenceClassification.from_pretrained('../model')
11
+ tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
12
+
13
+ def predict(text):
14
+ inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True)
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ logits = outputs.logits
18
+ prediction = torch.argmax(logits, dim=-1)
19
+ return prediction.item()
20
+
21
+ # Example usage
22
+ text = "Example text for prediction"
23
+ prediction = predict(text)
24
+ print(f"Prediction: {prediction}")