BennyKok commited on
Commit
182af0c
1 Parent(s): 97369e8

feat: add ws demo

Browse files
.env.example CHANGED
@@ -3,3 +3,5 @@ COMFY_API_TOKEN=""
3
  COMFY_DEPLOYMENT_ID=""
4
  COMFY_DEPLOYMENT_ID_IMG_2_IMG=""
5
  COMFY_DEPLOYMENT_ID_CONTROLNET=""
 
 
 
3
  COMFY_DEPLOYMENT_ID=""
4
  COMFY_DEPLOYMENT_ID_IMG_2_IMG=""
5
  COMFY_DEPLOYMENT_ID_CONTROLNET=""
6
+
7
+ COMFY_DEPLOYMENT_WS=""
bun.lockb CHANGED
Binary files a/bun.lockb and b/bun.lockb differ
 
package.json CHANGED
@@ -17,14 +17,17 @@
17
  "@radix-ui/react-tabs": "^1.0.4",
18
  "class-variance-authority": "^0.7.0",
19
  "clsx": "^2.1.0",
20
- "comfydeploy": "^0.0.11",
21
  "lucide-react": "^0.309.0",
22
  "next": "14.0.3",
23
  "react": "^18",
24
  "react-dom": "^18",
25
  "react-hook-form": "^7.49.3",
 
 
26
  "tailwind-merge": "^2.2.0",
27
  "tailwindcss-animate": "^1.0.7",
 
28
  "zod": "^3.22.4"
29
  },
30
  "devDependencies": {
 
17
  "@radix-ui/react-tabs": "^1.0.4",
18
  "class-variance-authority": "^0.7.0",
19
  "clsx": "^2.1.0",
20
+ "comfydeploy": "0.0.12",
21
  "lucide-react": "^0.309.0",
22
  "next": "14.0.3",
23
  "react": "^18",
24
  "react-dom": "^18",
25
  "react-hook-form": "^7.49.3",
26
+ "react-use-websocket": "^4.7.0",
27
+ "swr": "^2.2.5",
28
  "tailwind-merge": "^2.2.0",
29
  "tailwindcss-animate": "^1.0.7",
30
+ "use-debounce": "^10.0.0",
31
  "zod": "^3.22.4"
32
  },
33
  "devDependencies": {
src/app/page.tsx CHANGED
@@ -26,16 +26,21 @@ import {
26
 
27
  import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
28
  import { ImageGenerationResult } from "@/components/ImageGenerationResult";
 
29
 
30
  export default function Page() {
31
  return (
32
  <main className="flex min-h-screen flex-col items-center justify-between mt-10">
33
- <Tabs defaultValue="txt2img" className="w-full max-w-[600px]">
34
- <TabsList className="grid w-full grid-cols-3">
 
35
  <TabsTrigger value="txt2img">txt2img</TabsTrigger>
36
  <TabsTrigger value="img2img">img2img</TabsTrigger>
37
  <TabsTrigger value="controlpose">Controlpose</TabsTrigger>
38
  </TabsList>
 
 
 
39
  <TabsContent value="txt2img">
40
  <Txt2img />
41
  </TabsContent>
 
26
 
27
  import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
28
  import { ImageGenerationResult } from "@/components/ImageGenerationResult";
29
+ import { WebsocketDemo } from "@/components/WebsocketDemo";
30
 
31
  export default function Page() {
32
  return (
33
  <main className="flex min-h-screen flex-col items-center justify-between mt-10">
34
+ <Tabs defaultValue="ws" className="w-full max-w-[600px]">
35
+ <TabsList className="grid w-full grid-cols-4">
36
+ <TabsTrigger value="ws">Realtime</TabsTrigger>
37
  <TabsTrigger value="txt2img">txt2img</TabsTrigger>
38
  <TabsTrigger value="img2img">img2img</TabsTrigger>
39
  <TabsTrigger value="controlpose">Controlpose</TabsTrigger>
40
  </TabsList>
41
+ <TabsContent value="ws">
42
+ <WebsocketDemo />
43
+ </TabsContent>
44
  <TabsContent value="txt2img">
45
  <Txt2img />
46
  </TabsContent>
src/components/WebsocketDemo.tsx ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use client"
2
+
3
+ import { getWebsocketUrl } from '@/server/generate'
4
+ import { useCallback, useEffect, useRef, useState } from 'react'
5
+ import useSWR from 'swr'
6
+ import useWebSocket, { ReadyState } from 'react-use-websocket';
7
+ import { useDebounce } from "use-debounce";
8
+ import { Input } from './ui/input';
9
+ import { Badge } from './ui/badge';
10
+
11
+ export function WebsocketDemo() {
12
+ const { data } = useSWR("ws", getWebsocketUrl, {
13
+ revalidateOnFocus: false,
14
+ })
15
+ const [ws, setWs] = useState<WebSocket>()
16
+
17
+ const [status, setStatus] = useState("not-connected")
18
+ const [prompt, setPrompt] = useState('A anime cat');
19
+ const [debouncedPrompt] = useDebounce(prompt, 200);
20
+
21
+ const [currentLog, setCurrentLog] = useState<string>();
22
+
23
+ const [reconnectCounter, setReconnectCounter] = useState(0)
24
+
25
+ const canvasRef = useRef<HTMLCanvasElement>(null); // Reference to the canvas element
26
+
27
+ const sendInput = useCallback(() => {
28
+ if (status == "reconnecting")
29
+ return
30
+
31
+ if (ws?.readyState == ws?.CLOSED) {
32
+ setStatus('reconnecting')
33
+ setReconnectCounter(x => x + 1)
34
+ return
35
+ }
36
+
37
+ if (status != "ready")
38
+ return
39
+
40
+ ws?.send(JSON.stringify(
41
+ {
42
+ "event": "input",
43
+ "inputs": {
44
+ "input_text": debouncedPrompt
45
+ }
46
+ }
47
+ ))
48
+ }, [ws, debouncedPrompt, status])
49
+
50
+ const preStatus = useRef(status)
51
+
52
+ useEffect(() => {
53
+ if (preStatus.current != status && status == "ready")
54
+ sendInput();
55
+ preStatus.current = status
56
+ }, [status, sendInput])
57
+
58
+ useEffect(() => {
59
+ sendInput();
60
+ }, [debouncedPrompt])
61
+
62
+ const connectWS = useCallback((data: NonNullable<Awaited<ReturnType<typeof getWebsocketUrl>>>) => {
63
+ const websocket = new WebSocket(data.ws_connection_url);
64
+ websocket.binaryType = "arraybuffer";
65
+ websocket.onopen = () => {
66
+ setStatus("connected");
67
+ };
68
+ websocket.onmessage = (event) => {
69
+ if (typeof event.data === "string") {
70
+ const message = JSON.parse(event.data);
71
+ if (message?.event == "status" && message?.data?.sid) {
72
+ setStatus("ready");
73
+ }
74
+ if (message?.event) {
75
+ if (message?.event == "executing" && message?.data?.node == null)
76
+ setCurrentLog("done")
77
+ else if (message?.event == "live_status")
78
+ setCurrentLog(`running - ${message.data?.current_node} ${(message.data.progress * 100).toFixed(2)}%`)
79
+ else if (message?.event == "elapsed_time")
80
+ setCurrentLog(`elapsed time: ${Math.ceil(message.data?.elapsed_time * 100) / 100}s`)
81
+ }
82
+ console.log("Received message:", message);
83
+ }
84
+ if (event.data instanceof ArrayBuffer) {
85
+ console.log("Received binary message:");
86
+ drawImage(event.data);
87
+ }
88
+ };
89
+ websocket.onclose = () => setStatus("closed");
90
+ websocket.onerror = () => setStatus("error");
91
+
92
+ setWs(websocket);
93
+
94
+ return () => {
95
+ websocket.close();
96
+ };
97
+ }, [data])
98
+
99
+ const drawImage = useCallback((arrayBuffer: ArrayBuffer) => {
100
+ const view = new DataView(arrayBuffer);
101
+ const eventType = view.getUint32(0);
102
+ const buffer = arrayBuffer.slice(4);
103
+ switch (eventType) {
104
+ case 1:
105
+ const view2 = new DataView(arrayBuffer);
106
+ const imageType = view2.getUint32(0)
107
+ let imageMime
108
+ switch (imageType) {
109
+ case 1:
110
+ default:
111
+ imageMime = "image/jpeg";
112
+ break;
113
+ case 2:
114
+ imageMime = "image/png"
115
+ }
116
+ const blob = new Blob([buffer.slice(4)], { type: imageMime });
117
+
118
+ // const blob = new Blob([arrayBuffer], { type: 'image/png' }); // Assuming the image is a JPEG
119
+ const url = URL.createObjectURL(blob);
120
+
121
+ const canvas = canvasRef.current;
122
+ const ctx = canvas?.getContext('2d');
123
+
124
+ if (ctx) {
125
+ console.log("drawing");
126
+
127
+ const img = new Image();
128
+ img.onload = () => {
129
+ if (canvas) {
130
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
131
+ }
132
+ URL.revokeObjectURL(url); // Clean up
133
+ };
134
+ img.src = url;
135
+ }
136
+ // this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
137
+ break;
138
+ default:
139
+ throw new Error(`Unknown binary websocket message of type ${eventType}`);
140
+ }
141
+ }, []);
142
+
143
+ useEffect(() => {
144
+ if (!data) {
145
+ setStatus("not-connected");
146
+ return;
147
+ }
148
+
149
+ return connectWS(data)
150
+ }, [connectWS, reconnectCounter])
151
+
152
+ return (
153
+ <div className='flex flex-col gap-2'>
154
+ <div className='flex gap-2'>
155
+ <Badge variant={'outline'} className='w-fit'>Status: {status}</Badge>
156
+ <Badge variant={'outline'} className='w-fit'>
157
+ {currentLog}
158
+ {status == "connected" &&!currentLog && "stating comfy ui"}
159
+ {status == "ready" &&!currentLog && " running"}
160
+ </Badge>
161
+ </div>
162
+ <canvas ref={canvasRef} className='rounded-lg' width="1024" height="1024"></canvas>
163
+
164
+
165
+ <Input
166
+ id="picture"
167
+ type="text"
168
+ value={prompt}
169
+ onChange={(e) => setPrompt(e.target.value)}
170
+ />
171
+ </div>
172
+ )
173
+ }
src/components/ui/badge.tsx ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import * as React from "react"
2
+ import { cva, type VariantProps } from "class-variance-authority"
3
+
4
+ import { cn } from "@/lib/utils"
5
+
6
+ const badgeVariants = cva(
7
+ "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
8
+ {
9
+ variants: {
10
+ variant: {
11
+ default:
12
+ "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
13
+ secondary:
14
+ "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
15
+ destructive:
16
+ "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
17
+ outline: "text-foreground",
18
+ },
19
+ },
20
+ defaultVariants: {
21
+ variant: "default",
22
+ },
23
+ }
24
+ )
25
+
26
+ export interface BadgeProps
27
+ extends React.HTMLAttributes<HTMLDivElement>,
28
+ VariantProps<typeof badgeVariants> {}
29
+
30
+ function Badge({ className, variant, ...props }: BadgeProps) {
31
+ return (
32
+ <div className={cn(badgeVariants({ variant }), className)} {...props} />
33
+ )
34
+ }
35
+
36
+ export { Badge, badgeVariants }
src/server/generate.tsx CHANGED
@@ -7,7 +7,7 @@ const client = new ComfyDeployClient({
7
  apiToken: process.env.COMFY_API_TOKEN!,
8
  })
9
 
10
- export async function generate(prompt: string){
11
  return await client.run({
12
  deployment_id: process.env.COMFY_DEPLOYMENT_ID!,
13
  inputs: {
@@ -16,7 +16,7 @@ export async function generate(prompt: string){
16
  })
17
  }
18
 
19
- export async function generate_img(input_image: string){
20
  return await client.run({
21
  deployment_id: process.env.COMFY_DEPLOYMENT_ID_IMG_2_IMG!,
22
  inputs: {
@@ -25,7 +25,7 @@ export async function generate_img(input_image: string){
25
  })
26
  }
27
 
28
- export async function generate_img_with_controlnet(input_openpose_url: string, prompt: string){
29
  return await client.run({
30
  deployment_id: process.env.COMFY_DEPLOYMENT_ID_CONTROLNET!,
31
  inputs: {
@@ -35,14 +35,20 @@ export async function generate_img_with_controlnet(input_openpose_url: string, p
35
  })
36
  }
37
 
38
- export async function checkStatus(run_id: string){
39
  return await client.getRun(run_id)
40
  }
41
 
42
- export async function getUploadUrl(type: string, file_size: number){
43
  try {
44
  return await client.getUploadUrl(type, file_size)
45
  } catch (error) {
46
  console.log(error)
47
  }
48
  }
 
 
 
 
 
 
 
7
  apiToken: process.env.COMFY_API_TOKEN!,
8
  })
9
 
10
+ export async function generate(prompt: string) {
11
  return await client.run({
12
  deployment_id: process.env.COMFY_DEPLOYMENT_ID!,
13
  inputs: {
 
16
  })
17
  }
18
 
19
+ export async function generate_img(input_image: string) {
20
  return await client.run({
21
  deployment_id: process.env.COMFY_DEPLOYMENT_ID_IMG_2_IMG!,
22
  inputs: {
 
25
  })
26
  }
27
 
28
+ export async function generate_img_with_controlnet(input_openpose_url: string, prompt: string) {
29
  return await client.run({
30
  deployment_id: process.env.COMFY_DEPLOYMENT_ID_CONTROLNET!,
31
  inputs: {
 
35
  })
36
  }
37
 
38
+ export async function checkStatus(run_id: string) {
39
  return await client.getRun(run_id)
40
  }
41
 
42
+ export async function getUploadUrl(type: string, file_size: number) {
43
  try {
44
  return await client.getUploadUrl(type, file_size)
45
  } catch (error) {
46
  console.log(error)
47
  }
48
  }
49
+
50
+ export async function getWebsocketUrl() {
51
+ return await client.getWebsocketUrl({
52
+ deployment_id: process.env.COMFY_DEPLOYMENT_WS!,
53
+ })
54
+ }