aa1223 committed
Commit d1fced8
1 Parent(s): afe3ec0

Upload main.py

Files changed (1):
main.py +69 -0
main.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ from transformers import pipeline, BitsAndBytesConfig
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import requests
+ from PIL import Image
+ from io import BytesIO
+
+ # Set up device (CPU or GPU)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Configure 4-bit quantization when a GPU is available
+ if device == "cuda":
+     print("GPU found. Using 4-bit quantization.")
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+ else:
+     print("GPU not found. Using CPU with default settings.")
+     quantization_config = None
+
+ # Load the model pipeline. The quantization config has to be passed through
+ # to the pipeline; bitsandbytes places a 4-bit model on the GPU itself, so
+ # `device` is only set explicitly on the un-quantized CPU path.
+ model_id = "bczhou/tiny-llava-v1-hf"
+ if quantization_config is not None:
+     pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
+ else:
+     pipe = pipeline("image-to-text", model=model_id, device=device)
+
+ print(f"Using device: {device}")
+
+ # Initialize FastAPI application
+ app = FastAPI()
+
+ # Health check endpoint to confirm the API is running
+ @app.get("/")
+ async def root():
+     return {"message": "API is running fine."}
+
+ # Pydantic model describing the request body
+ class ImagePromptInput(BaseModel):
+     image_url: str
+     prompt: str
+
+ # FastAPI route for generating text from an image
+ @app.post("/generate")
+ async def generate_text(input_data: ImagePromptInput):
+     try:
+         # Download and decode the image
+         response = requests.get(input_data.image_url, timeout=30)
+         response.raise_for_status()  # Fail early on HTTP errors
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+         image = image.resize((750, 500))  # Resize to fixed dimensions
+
+         # Build the chat-style prompt the LLaVA model expects
+         full_prompt = f"USER: <image>\n{input_data.prompt}\nASSISTANT: "
+
+         # Generate a response with the model pipeline
+         outputs = pipe(image, prompt=full_prompt, generate_kwargs={"max_new_tokens": 200})
+
+         # Return the generated text
+         generated_text = outputs[0]["generated_text"]  # type: ignore
+         return {"response": generated_text}
+
+     except Exception as e:
+         # Surface the failure as an HTTP 500
+         raise HTTPException(status_code=500, detail=str(e))
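
For reference, a minimal client sketch against this API, not part of the commit itself. It assumes the server has been started locally with `uvicorn main:app --host 0.0.0.0 --port 8000` (uvicorn is not installed by this file), that the requests library is available, and that the image URL below is a reachable placeholder; all three are illustrative assumptions.

import requests

payload = {
    "image_url": "https://example.com/sample.jpg",  # hypothetical, reachable image URL
    "prompt": "What is shown in this image?",
}

# POST to the /generate route defined in main.py; generation can be slow on CPU,
# so allow a generous timeout. The endpoint returns {"response": "<generated text>"}.
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])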