File size: 1,493 Bytes
5534b61
3097d1f
 
 
 
92758a0
3097d1f
e2a8021
 
 
 
 
5534b61
 
 
e2a8021
 
 
 
 
 
5534b61
 
 
ac9d7ff
3097d1f
92758a0
e2a8021
3097d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

class RequestGenerate(BaseModel):
    """Request body for the ``/generate`` endpoint.

    Only ``prompt`` is required. The remaining fields are decoding options
    that the endpoint forwards unchanged to ``transformers.GenerationConfig``;
    omitted fields take the defaults declared here.
    """

    # NOTE: no trailing comma after ``Field(...)``: a trailing comma would make
    # the default a one-element tuple instead of the declared value.

    # Text that is tokenized and used as the generation prompt.
    prompt: str
    # Sample from the distribution (True) or decode greedily (False).
    do_sample: bool = Field(default=True, example=True)
    # Number of highest-probability tokens kept when sampling.
    # NOTE(review): with the default of 1 only the single most likely token
    # survives, so sampling is effectively greedy and `temperature` has no
    # visible effect. Confirm this default is intended.
    top_k: int = Field(default=1, example=1)
    # Temperature used to rescale logits when sampling.
    temperature: float = Field(default=0.9, example=0.9)
    # Maximum number of tokens generated after the prompt.
    max_new_tokens: int = Field(default=500, example=500)
    # Repetition penalty; 1.0 means no penalty.
    repetition_penalty: float = Field(default=1.5, example=1.5)

app = FastAPI()

# Hugging Face Hub id (or local path) of the causal LM to serve. Set the
# MODEL_NAME_OR_ID environment variable to serve a different checkpoint,
# e.g. "AI4Chem/ChemLLM-7B-Chat"; the default is the 2B ChemLLM model.
model_name_or_id = os.environ.get("MODEL_NAME_OR_ID", "AI4Chem/CHEMLLM-2b-1_5")

# The model and tokenizer are loaded once at import time and reused by every
# request handled below. trust_remote_code=True executes Python code shipped
# with the model repository, so only point this at repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_id, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_id, trust_remote_code=True
)

@app.get("/")
def greet_json():
    """Respond to ``GET /`` with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/generate")
def generate(req: RequestGenerate):
    """Generate a completion for ``req.prompt`` with the loaded causal LM.

    Args:
        req: The prompt plus decoding options. Each option is passed straight
            through to ``transformers.GenerationConfig``.

    Returns:
        ``{"text": ...}``: the first generated sequence, decoded with special
        tokens stripped. For causal LMs, ``generate()`` returns the prompt ids
        followed by the new tokens, so the text includes the prompt.
    """
    # Put the encoded prompt on the same device as the model weights so this
    # keeps working if the model is ever loaded onto a GPU.
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)

    generation_config = GenerationConfig(
        do_sample=req.do_sample,
        top_k=req.top_k,
        temperature=req.temperature,
        max_new_tokens=req.max_new_tokens,
        repetition_penalty=req.repetition_penalty,
        # Pad with the EOS id (a common choice for causal LMs that may have no
        # dedicated pad token).
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(**inputs, generation_config=generation_config)
    return {"text": tokenizer.decode(outputs[0], skip_special_tokens=True)}