import gradio as gr
import openai
import os
import json
from PIL import Image, ImageDraw
import io
import requests

# Dimensions of the HTML canvas used for selecting the edit region
canvas_width = 500
canvas_height = 400

html = f"""
<head>
    <style>
        #selectRect {{
            position: absolute;
            border: 1px dashed red;
            background-color: rgba(255, 0, 0, 0.3);
        }}
    </style>
</head>
<body>
    <canvas id="canvas-root" width="{canvas_width}" height="{canvas_height}"></canvas>
    <div id="selectRect"></div>
</body>
"""
scripts = """
async () => {
    let isSelecting = false;
    let startX, startY, endX, endY;
    const canvas = document.getElementById('canvas-root');
    const ctx = canvas.getContext('2d');
    const canvasRect = canvas.getBoundingClientRect();
    const selectRect = document.getElementById('selectRect');
    const coordinatesElement = document.querySelector('#rectangle textarea');

    // Start a selection only when the drag begins inside the canvas.
    function handleMouseDown(event) {
        startX = event.clientX - canvasRect.left;
        startY = event.clientY - canvasRect.top;
        if (startX >= 0 && startY >= 0 && startX <= canvasRect.width && startY <= canvasRect.height) {
            isSelecting = true;
        }
    }

    // Resize the selection overlay while dragging, clamp it to the canvas,
    // and write the rectangle into the hidden textbox as JSON.
    function handleMouseMove(event) {
        if (isSelecting) {
            endX = Math.min(event.clientX - canvasRect.left, canvasRect.width);
            endY = Math.min(event.clientY - canvasRect.top, canvasRect.height);
            endX = Math.max(0, endX);
            endY = Math.max(0, endY);
            const left = Math.min(startX, endX);
            const top = Math.min(startY, endY);
            const width = Math.abs(endX - startX);
            const height = Math.abs(endY - startY);
            selectRect.style.left = left + 'px';
            selectRect.style.top = top + 'px';
            selectRect.style.width = width + 'px';
            selectRect.style.height = height + 'px';
            coordinatesElement.value = `{"left": ${left}, "top": ${top}, "width": ${width}, "height": ${height}}`;
            // Dispatch an input event so Gradio registers the new value.
            coordinatesElement.dispatchEvent(new CustomEvent("input"));
        }
    }

    function handleMouseUp() {
        isSelecting = false;
    }

    document.addEventListener('mousedown', handleMouseDown);
    document.addEventListener('mousemove', handleMouseMove);
    document.addEventListener('mouseup', handleMouseUp);
}
"""
image_change = """
async () => {
    const canvas = document.getElementById('canvas-root');
    const ctx = canvas.getContext('2d');
    const canvasRect = canvas.getBoundingClientRect();
    const selectRect = document.getElementById('selectRect');
    // Reset any previous selection and clear the canvas.
    selectRect.style.left = 0;
    selectRect.style.top = 0;
    selectRect.style.width = 0;
    selectRect.style.height = 0;
    ctx.clearRect(0, 0, canvasRect.width, canvasRect.height);
    const img = document.querySelector('#input_image img');
    img.onload = function() {
        // Scale the image to fit the canvas while preserving its aspect ratio.
        let width, height;
        if ((img.naturalWidth / canvasRect.width) > (img.naturalHeight / canvasRect.height)) {
            width = canvasRect.width;
            height = img.naturalHeight * (width / img.naturalWidth);
        } else {
            height = canvasRect.height;
            width = img.naturalWidth * (height / img.naturalHeight);
        }
        ctx.drawImage(img, 0, 0, width, height);
    }
}
"""


def pil_to_bytes(pil_image, format='PNG'):
    image_bytes = io.BytesIO()
    pil_image.save(image_bytes, format=format)
    return image_bytes.getvalue()


def expand2square(image, background_color):
    # Pad the image to a square on the right/bottom, then resize to 2048x2048,
    # since the OpenAI edit endpoint expects a square image.
    width, height = image.size
    longest = max(width, height)
    result = Image.new(image.mode, (longest, longest), background_color)
    result.paste(image, (0, 0))
    return result.resize((2048, 2048))


def gen_mask(image, left, top, right, bottom):
    # Build a mask that is opaque everywhere except the selected rectangle;
    # the transparent area marks the region the API is allowed to repaint.
    mask = Image.new("RGBA", image.size, (0, 0, 0, 255))
    width = image.size[0]
    height = image.size[1]
    draw = ImageDraw.Draw(mask)
    draw.rectangle(
        [(left * width, top * height), (right * width, bottom * height)],
        fill=(255, 255, 255, 0),
    )
    return mask
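

# Main callback: build a mask from the selected rectangle and ask the
# OpenAI image-edit endpoint to repaint that region according to the prompt.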
def create_edit(image, rect, prompt, api_key, api_organization=None):
    openai.organization = api_organization
    openai.api_key = api_key
    # The hidden textbox holds the selection as a JSON string of canvas-pixel coordinates.
    rect = json.loads(rect)
    image.putalpha(alpha=255)
    square_image = expand2square(image, "black")
    # Convert the selection from canvas pixels to coordinates normalized to [0, 1].
    left, top, width, height = rect["left"], rect["top"], rect["width"], rect["height"]
    left, top, right, bottom = left / canvas_width, top / canvas_height, (left + width) / canvas_width, (top + height) / canvas_height
    response = openai.Image.create_edit(
        image=pil_to_bytes(square_image),
        mask=pil_to_bytes(gen_mask(square_image, left, top, right, bottom)),
        prompt=prompt,
        n=1,
        size="512x512",
    )
    # Download the edited image and crop away the padding added by expand2square.
    edited_image_url = response['data'][0]['url']
    edited_image = requests.get(edited_image_url)
    edited_image = Image.open(io.BytesIO(edited_image.content))
    raw_width, raw_height = image.size
    raw_longest = max(raw_width, raw_height)
    crop_width = raw_width * edited_image.size[0] / raw_longest
    crop_height = raw_height * edited_image.size[1] / raw_longest
    cropped_edited_image = edited_image.crop((0, 0, crop_width, crop_height))
    return cropped_edited_image
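

# Gradio UI: prompt and images on the left, the selection canvas and the
# edit button on the right, plus a hidden textbox that carries the selection
# coordinates from JavaScript to Python.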
with gr.Blocks() as demo:
    with gr.Accordion("OpenAI API Settings", open=False):
        api_key = gr.Textbox(label="OpenAI API key", placeholder="OpenAI API key")
        api_organization = gr.Textbox(label="OpenAI API organization", placeholder="OpenAI API organization (optional)")
    with gr.Column():
        with gr.Row():
            with gr.Column():
                prompt_text = gr.Textbox(label="Prompt")
                prompt_examples = gr.Examples(
                    examples=[
                        "White plate.",
                        "A cherry on top of the pasta.",
                        "Curry.",
                    ],
                    inputs=[prompt_text],
                    outputs=None,
                )
                in_image = gr.Image(label="Input", elem_id="input_image", type="pil")
                image_examples = gr.Examples(
                    examples=[
                        "images/001.jpg",
                        "images/002.jpg",
                        "images/003.jpg",
                    ],
                    inputs=[in_image],
                    outputs=None,
                )
                out_image = gr.Image(label="Output")
            with gr.Column():
                gr.Markdown(
                    """
# Specify the edit region
Drag on the canvas to select the area to be masked for editing.
""")
                input_mic = gr.HTML(html)
                btn = gr.Button(value="Image Edit")
    rect_text = gr.Textbox(elem_id="rectangle", visible=False)
    in_image.change(None, inputs=None, outputs=None, _js=image_change)
    btn.click(create_edit, inputs=[in_image, rect_text, prompt_text, api_key, api_organization], outputs=[out_image])
    demo.load(_js=scripts)

demo.launch()