virginie-d committed
Commit 380cd6b
Parent: aba07ec

First batch of updates

Files changed (2)
  1. handler.py +43 -1
  2. requirements.txt +10 -0
handler.py CHANGED
@@ -1 +1,43 @@
- # test
+ from typing import Dict, List, Any
+ from transformers import pipeline
+ import torch, PIL, transformers, triton, sentencepiece, google.protobuf  # protobuf is importable as google.protobuf
+ import torchvision, einops
+ import xformers, accelerate
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         self.model = AutoModelForCausalLM.from_pretrained(
+             'THUDM/cogvlm-chat-hf',
+             torch_dtype=torch.bfloat16,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True,
+             # cache_dir='/tmp'
+         )
+         self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+         # create inference pipeline
+         # self.pipeline = pipeline(model=model, tokenizer=tokenizer)
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         """
+         Args:
+             data (:obj:`dict`):
+                 The request payload. ``data["inputs"]`` is expected to be a dict of
+                 generation-ready tensors (e.g. ``input_ids``, ``attention_mask``, ``images``).
+         Return:
+             A :obj:`str`: the text generated by the model, decoded with the tokenizer
+             after stripping the prompt tokens from the output.
+         """
+         inputs = data.pop("inputs", data)
+         gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+         # pass inputs with all kwargs in data
+         # prediction = self.pipeline(inputs)
+
+         outputs = self.model.generate(**inputs, **gen_kwargs)
+         outputs = outputs[:, inputs['input_ids'].shape[1]:]  # keep only the newly generated tokens
+         prediction = self.tokenizer.decode(outputs[0])
+
+         # post process the prediction
+         return prediction
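
Note on usage (not part of the commit): the handler passes data["inputs"] straight to model.generate(), so the caller has to supply generation-ready tensors. The sketch below shows one way to build such a payload locally, following the input-construction pattern from the THUDM/cogvlm-chat-hf model card; build_conversation_input_ids comes from the model's remote code, and the query text, image path, and batching details are illustrative assumptions rather than part of this repository.

import torch
from PIL import Image
from handler import EndpointHandler

handler = EndpointHandler()

# Placeholder prompt and image; CogVLM's remote code exposes
# build_conversation_input_ids() to turn them into model inputs.
image = Image.open("example.jpg").convert("RGB")
features = handler.model.build_conversation_input_ids(
    handler.tokenizer,
    query="Describe this image.",
    history=[],
    images=[image],
)

# generate() expects batched tensors; the image goes in as a nested list.
# Everything stays on CPU here because the handler above does not move the
# model to a GPU; add .to('cuda') on the model and tensors if one is available.
inputs = {
    "input_ids": features["input_ids"].unsqueeze(0),
    "token_type_ids": features["token_type_ids"].unsqueeze(0),
    "attention_mask": features["attention_mask"].unsqueeze(0),
    "images": [[features["images"][0].to(torch.bfloat16)]],
}

print(handler({"inputs": inputs}))
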
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ Pillow
+ transformers
+ triton
+ sentencepiece
+ protobuf
+ torchvision
+ einops
+ xformers
+ accelerate