|
import numpy as np
|
|
import pandas as pd
|
|
from PIL import Image
|
|
from collections import defaultdict
|
|
|
|
import streamlit as st
|
|
from streamlit_drawable_canvas import st_canvas
|
|
|
|
import matplotlib as mpl
|
|
|
|
from model import device, segment_image, inpaint
|
|
|
|
|
|
|
|
def closest_number(n, m=8):
    """Round ``n`` down to a multiple of ``m``.

    Despite the name, this does not pick the *nearest* multiple: the quotient
    ``n / m`` is truncated toward zero, so for non-negative ``n`` the result is
    the largest multiple of ``m`` that does not exceed ``n``
    (e.g. ``closest_number(15) == 8``). The image-resizing code in this file
    relies on that: shrinking to a multiple of ``m`` never enlarges the image.

    Args:
        n: Value to round; may be a float (e.g. ``width / sf``).
        m: The result is a multiple of this. Defaults to 8.

    Returns:
        An ``int`` multiple of ``m``. Note that it is 0 whenever ``0 <= n < m``.

    Raises:
        ZeroDivisionError: if ``m`` is 0.
    """
    # int() truncates toward zero (it does not round), hence "round down" for n >= 0.
    return int(n / m) * m
|
|
|
|
|
|
def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw shapes over ``image`` and paint their boxes into ``mask``.

    Renders a ``streamlit_drawable_canvas`` with ``image`` as its background.
    For every drawn object, its bounding box (from the canvas's
    ``left``/``top``/``width``/``height``, times ``scaleX``/``scaleY`` when
    present) is cropped from ``image`` and displayed, and the same region of
    ``mask`` is set to 255.

    Args:
        image: PIL image shown behind the canvas.
        mask: 2-D ``uint8`` numpy array indexed ``[row, col]``; it is modified
            in place.
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Canvas drawing mode; ``'rect'`` by default.

    Returns:
        ``mask`` (the same array object) with drawn regions set to 255.
    """
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    if canvas_result.json_data is None:
        return mask

    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if len(objects) == 0:
        return mask

    def _scale(column):
        # Scale factors may be absent (or NaN for some objects); treat as 1.
        if column not in objects:
            return np.ones(len(objects), dtype=float)
        return objects[column].fillna(1.0).to_numpy(dtype=float)

    left = objects["left"].to_numpy(dtype=float)
    top = objects["top"].to_numpy(dtype=float)
    # Fabric.js reports width/height before scaling; a shape resized with the
    # transform handles keeps its factor in scaleX/scaleY.
    right = left + objects["width"].to_numpy(dtype=float) * _scale("scaleX")
    bottom = top + objects["height"].to_numpy(dtype=float) * _scale("scaleY")

    # Canvas coordinates are floats, but numpy slices need integers. Grow each
    # box outward to whole pixels and clip it to the mask so that shapes drawn
    # past an edge neither wrap around (negative indices) nor overflow.
    mask_h, mask_w = mask.shape[:2]
    x0 = np.clip(np.floor(left), 0, mask_w).astype(int)
    y0 = np.clip(np.floor(top), 0, mask_h).astype(int)
    x1 = np.clip(np.ceil(right), 0, mask_w).astype(int)
    y1 = np.clip(np.ceil(bottom), 0, mask_h).astype(int)

    for l, t, r, b in zip(x0, y0, x1, y1):
        l, t, r, b = int(l), int(t), int(r), int(b)
        if r <= l or b <= t:
            # Zero-area box (e.g. a click without a drag) or entirely off-canvas.
            continue
        st.image(image.crop((l, t, r, b)))
        mask[t:b, l:r] = 255

    st.header("Mask Created!")
    st.image(mask)

    return mask
|
|
|
|
|
|
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` using the chosen edit method.

    Args:
        image: PIL image to be edited.
        edit_method: ``"AutoSegment Area"`` (pick segments predicted by
            ``segment_image``) or ``"Draw Custom Area"`` (draw shapes on a
            canvas). Any other value yields the all-zero mask.
        height: Height of the default all-zero mask.
        width: Width of the default all-zero mask.

    Returns:
        An all-zero ``(height, width)`` ``uint8`` array when nothing has been
        selected. For ``"AutoSegment Area"`` with a selection, a PIL image that
        is 255 on the chosen segments and 0 elsewhere. For
        ``"Draw Custom Area"``, whatever ``get_mask_from_rectangles`` returns.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    if edit_method == "AutoSegment Area":
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()

        # Matplotlib uses integer inputs as direct lookup-table indices, so ids
        # 0..max(seg) need max(seg) + 1 entries. With only max(seg) entries the
        # largest id falls into the "over" colour and looks identical to its
        # neighbour. NOTE(review): assumes `seg` holds integer segment ids.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # Options are (segment_id, label) pairs.
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            selected_ids = [segment_id for segment_id, _label in seg_selections]
            # Each pixel carries exactly one id, so membership in the selected
            # ids marks exactly the chosen segments.
            selected = np.isin(seg, selected_ids)
            mask = Image.fromarray(selected.astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)

    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)

    return mask
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")

    st.write(f"Device found: {device}")

    # Integer factor by which the uploaded image is shrunk before editing.
    sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(sf)
    except ValueError:
        sf = None
    # Zero would divide by zero below and negatives give negative image sizes.
    if sf is None or sf < 1:
        # Report through Streamlit (`st`); `sf` is a plain string here.
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    filename = st.file_uploader("upload an image")

    if filename:
        image = Image.open(filename)

        # Downscale and snap both sides down to a multiple of 8.
        width, height = image.size
        width, height = closest_number(width / sf), closest_number(height / sf)
        if width == 0 or height == 0:
            st.error(
                f"Image is too small to downsize by a factor of {sf}; "
                "please upload a larger image or use a smaller scale factor."
            )
            st.stop()
        image = image.resize((width, height))

        st.image(image)

        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))

        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            prompt = st.text_input("Please enter prompt for image inpainting", value="")

            if prompt:
                st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
                images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                st.write("Original Image")
                st.image(image)
                for i, img in enumerate(images, 1):
                    st.write(f"result: {i}")
                    st.image(img)
|
|
|
|
|