Commit
•
d90a235
1
Parent(s):
f364bfe
Update handler.py
Browse files- handler.py +21 -11
handler.py
CHANGED
@@ -9,16 +9,26 @@ class EndpointHandler:
|
|
9 |
).to("cuda")
|
10 |
|
11 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
prediction = self.pipeline(
|
23 |
prompt,
|
24 |
negative_prompt=negative_prompt,
|
@@ -26,6 +36,6 @@ class EndpointHandler:
|
|
26 |
width=width,
|
27 |
guidance_scale=guidance_scale,
|
28 |
num_inference_steps=num_inference_steps,
|
29 |
-
generator=
|
30 |
).images[0]
|
31 |
return prediction
|
|
|
9 |
).to("cuda")
|
10 |
|
11 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Handle one text-to-image inference request for EndpointHandler.

    Args:
        data: Request payload. May be the payload itself or wrapped under a
            ``"json"`` key (both shapes are accepted). Expected keys:
            ``"inputs"`` (str, required prompt) and ``"parameters"``
            (dict, optional generation overrides).

    Returns:
        The first generated image from the pipeline output
        (``.images[0]``).
        NOTE(review): despite the annotation ``List[Dict[str, Any]]``
        (kept unchanged for interface stability), the actual return value
        is a single image object — confirm against the serving framework.

    Raises:
        ValueError: If no prompt is supplied under ``"inputs"``.
    """
    # Unwrap an optional {"json": {...}} envelope; fall back to the
    # payload itself when the key is absent.
    data = data.get("json", data)
    prompt = data.get("inputs", None)
    parameters = data.get("parameters", {})
    if not prompt:
        raise ValueError("Input prompt is missing.")

    # Generation parameters with serving defaults.
    negative_prompt = parameters.get("negative_prompt", "bad quality, worse quality, deformed")
    height = parameters.get("height", 512)
    width = parameters.get("width", 512)
    guidance_scale = parameters.get("guidance_scale", 4.5)
    num_inference_steps = parameters.get("num_inference_steps", 28)
    seed = parameters.get("seed", 0)

    # Deterministic sampling: torch.manual_seed seeds the global RNG and
    # returns the (CPU) default Generator, which diffusers accepts even
    # for a CUDA pipeline.
    generator = torch.manual_seed(seed)

    # Run the diffusion pipeline. `height` was extracted above and must
    # be forwarded here, otherwise the parameter is silently ignored.
    prediction = self.pipeline(
        prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    return prediction
|