# Source: Hugging Face Space "Modeltest", file app.py (author: CosmoAI).
# Last commit: 8fb1275 "Update app.py" (verified), file size 2.82 kB.
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import nltk
from nltk.stem.porter import PorterStemmer
import json
import numpy as np
import random
import streamlit as st
# Tokenizer data for nltk.word_tokenize. NLTK >= 3.9 loads the "punkt_tab"
# resource instead of the legacy pickled "punkt" model, so fetching only
# "punkt" makes word_tokenize raise LookupError on fresh installs. Fetch both
# so the app works on old and new NLTK releases alike.
nltk.download('punkt')
nltk.download('punkt_tab')
class _NeuralNet(nn.Module):
    """Three-layer feed-forward intent classifier.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear. The attribute names
    (``l1``, ``l2``, ``l3``) must match the keys stored in the checkpoint's
    ``model_state``, otherwise ``load_state_dict`` rejects it.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.l1(x))
        out = self.relu(self.l2(out))
        # Raw logits; ExecuteQuery applies softmax itself.
        return self.l3(out)


_STEMMER = PorterStemmer()


@st.cache_resource
def _load_resources():
    """Load the intents file and the trained classifier once per server process.

    Streamlit re-executes the whole script on every interaction, so module-level
    state would be rebuilt each time; ``st.cache_resource`` keeps the model alive
    across reruns and sessions.

    Returns:
        tuple: ``(model, device, intents, all_words, tags)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with open('files/intents.json', 'r', encoding='utf-8') as json_data:
        intents = json.load(json_data)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    data = torch.load("files/intents.pth", map_location=device)
    model = _NeuralNet(
        data["input_size"], data["hidden_size"], data["output_size"]
    ).to(device)
    model.load_state_dict(data["model_state"])
    model.eval()
    return model, device, intents, data["all_words"], data["tags"]


def _bag_of_words(tokenized_sentence, words):
    """Return a float32 vector with 1.0 at index i when ``words[i]`` matches a
    lower-cased, Porter-stemmed token of ``tokenized_sentence``, else 0.0."""
    stems = {_STEMMER.stem(token.lower()) for token in tokenized_sentence}
    return np.array([1.0 if w in stems else 0.0 for w in words], dtype=np.float32)


def ExecuteQuery(query, threshold=0.96):
    """Classify ``query`` into an intent and pick a canned reply for it.

    Args:
        query: User input; converted with ``str()`` before tokenizing.
        threshold: Minimum softmax probability of the predicted tag required
            to answer from ``intents.json``. Defaults to 0.96.

    Returns:
        tuple: ``(reply, tag, prob)`` where ``prob`` is a Python float. When the
        probability is below ``threshold``, or the predicted tag has no entry in
        the intents file, ``reply`` and ``tag`` are both ``"opencosmo"``.
    """
    model, device, intents, all_words, tags = _load_resources()

    tokens = nltk.word_tokenize(str(query))
    # Shape (1, len(all_words)): a batch containing a single sentence.
    X = torch.from_numpy(_bag_of_words(tokens, all_words)).unsqueeze(0).to(device)

    # Inference only: no_grad avoids building an autograd graph per query.
    with torch.no_grad():
        output = model(X)

    _, predicted = torch.max(output, dim=1)
    idx = predicted.item()
    tag = tags[idx]
    prob = torch.softmax(output, dim=1)[0][idx].item()

    if prob >= threshold:
        for intent in intents['intents']:
            if intent["tag"] == tag:
                return random.choice(intent["responses"]), tag, prob

    # Low confidence, or a confident tag with no responses defined: always
    # return a tuple so callers can index the result safely.
    return "opencosmo", "opencosmo", prob
# Streamlit entry point: read a query from the text box and, once the user has
# typed something, show the chosen reply on the page. The predicted tag and its
# probability go to the server's stdout for debugging.
query = st.text_input("Enter your query: ")
if query:
    answer, intent_tag, confidence = ExecuteQuery(query)
    st.write(answer)
    print(f"Tag: {intent_tag}")
    print(f"Prob: {confidence}")