File size: 1,658 Bytes
54d3b67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f98ee0e
54d3b67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import uvicorn
from fastapi import File
from fastapi import FastAPI
from fastapi import UploadFile
import torch
import os
import sys
import glob
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM


# Module-level setup: executed once at import time, before any request is served.
print("Loading models...")
app = FastAPI()

# Device that correct() moves its input tensors to; the model itself is never
# moved explicitly in this file.
device = "cpu"
# Hugging Face model identifier used for both the tokenizer and the model.
correction_model_tag = "prithivida/grammar_error_correcter_v2"
correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
correction_model     = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)

def set_seed(seed):
    """Seed PyTorch's random number generators.

    Args:
        seed: Value passed to ``torch.manual_seed`` and, when CUDA is
            available, to ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    # Nothing more to seed on a machine without CUDA.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Printed at import time once the tokenizer and model loads above have returned.
print("Models loaded !")


@app.get("/")
def read_root():
    """Handle GET requests on the root path with a fixed banner.

    Returns:
        A one-element set containing the banner string.
    """
    # NOTE(review): braces around a lone string build a set, not a dict;
    # presumably the framework serializes it as a JSON array -- confirm
    # whether a {"key": "value"} mapping was intended.
    banner = {"Gramformer !"}
    return banner

@app.get("/{correct}")
def get_correction(input_sentence):
    """Return the corrected form of ``input_sentence``.

    The random generators are re-seeded with a fixed value on every call
    before the sentence is handed to ``correct``.

    Args:
        input_sentence: Text to correct.

    Returns:
        A dict with the single key ``"scored_corrected_sentence"`` holding
        whatever ``correct`` returned.
    """
    # NOTE(review): the route declares a `{correct}` path segment that this
    # handler never receives; input_sentence presumably arrives as a query
    # parameter -- confirm the intended URL shape.
    set_seed(1212)
    return {"scored_corrected_sentence": correct(input_sentence)}

def correct(input_sentence, max_candidates=1):
    """Generate a corrected version of ``input_sentence``.

    The sentence is prefixed with ``"gec: "``, tokenized, and passed to the
    module-level seq2seq model, which samples ``max_candidates`` outputs.
    Decoded outputs are de-duplicated and the first one generated is
    returned.

    Args:
        input_sentence: Text to correct.
        max_candidates: Number of sequences to sample from the model.

    Returns:
        A ``(corrected_sentence, 0)`` tuple; the second element is a
        placeholder score.
    """
    correction_prefix = "gec: "
    input_ids = correction_tokenizer.encode(
        correction_prefix + input_sentence, return_tensors='pt'
    ).to(device)

    preds = correction_model.generate(
        input_ids,
        do_sample=True,
        max_length=128,
        top_k=50,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=max_candidates)

    # De-duplicate while keeping generation order. A plain set would make the
    # "first" candidate depend on string-hash iteration order, so the result
    # could differ between processes even with a fixed seed.
    candidates = list(dict.fromkeys(
        correction_tokenizer.decode(pred, skip_special_tokens=True).strip()
        for pred in preds
    ))
    return candidates[0], 0  # Corrected sentence, dummy score