Gabriel committed on
Commit
6e7afc8
1 Parent(s): 3720ff1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -16
handler.py CHANGED
@@ -1,26 +1,50 @@
1
- from typing import Dict, List, Any
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
 
 
3
  import torch
 
 
 
 
 
4
 
5
class EndpointHandler():
    """Custom inference handler that runs TrOCR text recognition on an image.

    The processor and model are loaded once in ``__init__``; each call turns
    the supplied image into pixel values, generates token ids and decodes
    them to text.
    """

    def __init__(self, path=""):
        """Load the TrOCR processor and model found at ``path``.

        Args:
            path: Directory (or hub id) passed to ``from_pretrained``.
        """
        self.processor = TrOCRProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)

        # Keep the device on the instance: __call__ needs it too, and a local
        # variable here would be out of scope there (NameError at inference).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Recognise the text in the image carried by ``data``.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]``; when there is no ``"inputs"``
                key, ``data`` itself is searched for ``"image"``.

        Returns:
            ``{"text": <decoded string>}`` on success, or
            ``{"error": "No image provided."}`` when no image was supplied.
        """
        inputs = data.pop("inputs", data)
        image_input = inputs.get("image")

        # Fail with a readable message instead of an opaque processor error.
        if image_input is None:
            return {"error": "No image provided."}

        # process image
        pixel_values = self.processor(
            images=image_input, return_tensors="pt"
        ).pixel_values

        # run prediction on the same device the model lives on
        generated_ids = self.model.generate(pixel_values.to(self.device))

        # decode output
        prediction = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return {"text": prediction[0]}
 
 
1
+ import base64
2
+ import io
3
+ from typing import Any, Dict, List
4
+
5
+ import requests
6
  import torch
7
+ from PIL import Image
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
+
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+
14
class EndpointHandler:
    """Custom inference handler that runs TrOCR text recognition on one image.

    The image may be supplied as an ``http(s)`` URL, as a base64-encoded
    string, or as a ``data:`` URI wrapping base64 data. Failures are reported
    as ``{"error": ...}`` dictionaries rather than raised.
    """

    # Seconds to wait for a remote image; without a timeout a stalled server
    # would block the worker indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, path=""):
        """Load the TrOCR processor and model found at ``path``.

        Args:
            path: Directory (or hub id) passed to ``from_pretrained``.
        """
        self.processor = TrOCRProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)

        # Stored on the instance so __call__ does not depend on a module global.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Recognise the text in the image referenced by ``data``.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]`` (or ``data["image"]`` when there
                is no ``"inputs"`` key). If ``inputs`` is not a dict it is
                treated as the image string itself.

        Returns:
            ``{"text": <decoded string>}`` on success, otherwise
            ``{"error": <message>}`` describing what went wrong.
        """
        inputs = data.pop("inputs", data)
        # Accept both {"inputs": {"image": ...}} and {"inputs": "<url|base64>"}.
        image_input = inputs.get("image") if isinstance(inputs, dict) else inputs

        if not image_input:
            return {"error": "No image provided."}
        if not isinstance(image_input, str):
            return {
                "error": "Failed to process the image. Details: expected a URL "
                f"or base64 string, got {type(image_input).__name__}"
            }

        # Broad catch is deliberate: this is the endpoint boundary, and network,
        # base64 and image-decoding errors should all become an error payload.
        try:
            if image_input.startswith(("http://", "https://")):
                # NOTE(security): the URL comes from the request, so this can
                # reach internal hosts (SSRF); restrict it if the endpoint is
                # exposed to untrusted callers.
                response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
                if response.status_code != 200:
                    return {
                        "error": f"Failed to fetch image. Status code: {response.status_code}"
                    }
                # .content is fully read and content-decoded; .raw is neither.
                image_bytes = response.content
            else:
                if image_input.startswith("data:"):
                    # Drop the "data:<mime>;base64," prefix of a data URI.
                    image_input = image_input.partition(",")[2]
                image_bytes = base64.b64decode(image_input)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process the image. Details: {str(e)}"}

        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        generated_ids = self.model.generate(pixel_values.to(self.device))

        prediction = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return {"text": prediction[0]}