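"""Inference script for the English-to-Hausa transformer.

Loads the trained tokenizers and model weights, turns a source sentence into a
padded encoder input, greedily decodes the translation, and exposes it through
a small Streamlit UI.
"""
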
import torch

from model import build_transformer
from train import greedy_decode, get_model, get_or_build_tokenizer
from config import get_config, get_weights_file_path
from tokenizers import Tokenizer
from pathlib import Path

config = get_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def process_text(config, src_text, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    # Special-token ids come from the target tokenizer, mirroring the training dataset.
    sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
    eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
    pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    # Tokenize the source sentence and work out how much padding is needed
    # once [SOS] and [EOS] are added.
    enc_input_tokens = tokenizer_src.encode(src_text).ids
    enc_num_padding_tokens = seq_len - len(enc_input_tokens) - 2
    if enc_num_padding_tokens < 0:
        raise ValueError("Sentence is too long")

    # Encoder input: [SOS] + tokens + [EOS] + padding, exactly seq_len positions.
    encoder_input = torch.cat(
        [
            sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            eos_token,
            torch.full((enc_num_padding_tokens,), pad_token.item(), dtype=torch.int64),
        ],
        dim=0,
    )

    assert encoder_input.size(0) == seq_len

    return {
        "encoder_input": encoder_input,
        # (1, 1, seq_len) mask that hides the padding positions from attention.
        "encoder_mask": (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).int(),
    }


def causal_mask(size):
    # True where attention is allowed: each position may attend to itself and earlier positions only.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


def infer(text, config):
    # Rebuild the tokenizers and the model, then load the trained checkpoint weights.
    tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    state = torch.load('tmodel_36.pt', map_location=device)
    model.load_state_dict(state['model_state_dict'])

    model.eval()
    with torch.no_grad():
        # Turn the raw text into a padded encoder input plus its padding mask,
        # and add the batch dimension the DataLoader would normally provide.
        processed_text = process_text(config, text, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
        encoder_input = processed_text['encoder_input'].unsqueeze(0).to(device)
        encoder_mask = processed_text['encoder_mask'].to(device)

        # Greedy autoregressive decoding, then map the token ids back to text.
        model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config['seq_len'], device)
        model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
        return model_out_text
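
# Example of calling the inference function directly, outside the UI:
#   print(infer("Good morning", config))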


# Minimal Streamlit front end around the inference function.
import streamlit as st

st.title("English to Hausa Translator")

user_input = st.text_input("Enter your text:")
if user_input:
    result = infer(user_input, config)
    st.write("Inference Result:", result)
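
# Launch locally with:
#   streamlit run <path-to-this-script>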