Virtual-Try-On / src /background_processor.py
parokshsaxena
Removing naive harmonization; adding functions for creating masks
0c774a3
raw
history blame
17 kB
import os
import requests
import logging
from PIL import Image, ImageEnhance
import cv2
import numpy as np
from preprocess.humanparsing.run_parsing import Parsing
from src.image_format_convertor import ImageFormatConvertor
# API key for the remove.bg service, read from the REMOVE_BG_KEY environment
# variable (None when the variable is unset). It is sent as the X-Api-Key
# header by BackgroundProcessor.remove_bg.
REMOVE_BG_KEY = os.getenv('REMOVE_BG_KEY')
# Human-parsing model shared by the mask-based compositing helpers. It is
# constructed once, at import time.
# NOTE(review): the argument 0 is presumably a GPU/device id - confirm against Parsing.
parsing_model = Parsing(0)
class BackgroundProcessor:
@classmethod
def add_background(cls, human_img: Image.Image, background_img: Image.Image) -> Image.Image:
    """Composite the person found in ``human_img`` onto ``background_img``.

    The module-level ``parsing_model`` is run on the person image and its
    first return value is converted to a grayscale mask. Every pixel whose
    mask value is non-zero is taken from the person image; every other pixel
    is taken from the background, which is resized to the person image's
    dimensions.

    .. deprecated::
        Created only for testing. Not in use. Calling it emits a
        ``DeprecationWarning``.

    Args:
        human_img: Image containing the person.
        background_img: Image to use as the new background.

    Returns:
        A new RGB image with the same size as ``human_img``.
    """
    # The original bare ``DeprecationWarning(...)`` expression only built an
    # exception object and threw it away; warn explicitly instead.
    import warnings

    warnings.warn(
        "Created only for testing. Not in use",
        DeprecationWarning,
        stacklevel=2,
    )

    human_img = human_img.convert("RGB")
    size = (human_img.width, human_img.height)

    # Build the mask from the parser output and bring it to the image size.
    parsed_img, _ = parsing_model(human_img)
    mask_img = parsed_img.convert("L").resize(size)

    background_img = background_img.convert("RGB").resize(size)

    human_np = np.array(human_img)
    background_np = np.array(background_img)
    # Replicate the single-channel mask so it lines up with the RGB arrays.
    mask_np = np.stack((np.array(mask_img),) * 3, axis=-1)

    # Non-zero mask -> person pixel, zero mask -> background pixel.
    human_with_background = np.where(mask_np > 0, human_np, background_np)

    return Image.fromarray(human_with_background.astype('uint8'))
@classmethod
def add_background_v3(cls, foreground_pil: Image.Image, background_pil: Image.Image) -> Image.Image:
    """Composite the person in ``foreground_pil`` onto ``background_pil`` using OpenCV.

    The module-level ``parsing_model`` supplies a mask. Any non-zero mask
    value is treated as "person". The person pixels and the complementary
    background pixels are combined, and the whole composite is then passed
    through a 5x5 Gaussian blur.

    .. deprecated::
        Created only for testing. Not in use. Calling it emits a
        ``DeprecationWarning``.

    Args:
        foreground_pil: Image containing the person.
        background_pil: Image to use as the new background; it is resized to
            the foreground's dimensions.

    Returns:
        The blurred composite as returned by ``ImageFormatConvertor.cv2_to_pil``.
    """
    # The original bare ``DeprecationWarning(...)`` expression only built an
    # exception object and threw it away; warn explicitly instead.
    import warnings

    warnings.warn(
        "Created only for testing. Not in use",
        DeprecationWarning,
        stacklevel=2,
    )

    foreground_pil = foreground_pil.convert("RGB")
    size = (foreground_pil.width, foreground_pil.height)

    # Create the mask image at the foreground's resolution.
    parsed_img, _ = parsing_model(foreground_pil)
    mask_pil = parsed_img.convert("L").resize(size)

    background_pil = background_pil.convert("RGB").resize(size)

    # Convert PIL images to OpenCV format.
    foreground_cv2 = ImageFormatConvertor.pil_to_cv2(foreground_pil)
    background_cv2 = ImageFormatConvertor.pil_to_cv2(background_pil)
    # The mask is mode "L", so this is already a 2-D single-channel array and
    # needs no colour conversion.
    mask_cv2 = np.array(mask_pil)

    # Threshold at 0: every non-zero mask value becomes 255 (pure black/white).
    _, mask_cv2 = cv2.threshold(mask_cv2, 0, 255, cv2.THRESH_BINARY)
    mask_inv_cv2 = cv2.bitwise_not(mask_cv2)

    # bitwise_and needs masks with the same channel count as the images.
    mask_3ch_cv2 = cv2.cvtColor(mask_cv2, cv2.COLOR_GRAY2BGR)
    mask_inv_3ch_cv2 = cv2.cvtColor(mask_inv_cv2, cv2.COLOR_GRAY2BGR)

    # Person pixels from the foreground, everything else from the background.
    person_cv2 = cv2.bitwise_and(foreground_cv2, mask_3ch_cv2)
    background_extracted_cv2 = cv2.bitwise_and(background_cv2, mask_inv_3ch_cv2)
    combined_cv2 = cv2.add(person_cv2, background_extracted_cv2)

    # Soften the seam. Note the blur is applied to the entire composite, not
    # only to the mask boundary.
    blurred_combined_cv2 = cv2.GaussianBlur(combined_cv2, (5, 5), 0)

    return ImageFormatConvertor.cv2_to_pil(blurred_combined_cv2)
@classmethod
def replace_background(cls, foreground_img_path: str, background_img_path: str):
    """Alpha-blend a cut-out foreground image over a background image.

    The result is written to the hard-coded file
    ``'path_to_save_combined_image.png'`` and shown in an OpenCV window that
    blocks until a key is pressed. Nothing is returned.

    .. deprecated::
        Created only for testing. Not in use. Calling it emits a
        ``DeprecationWarning``.

    Args:
        foreground_img_path: Path of the foreground image; after conversion
            with ``ImageFormatConvertor.pil_to_cv2`` it must have four
            channels (the fourth being alpha).
        background_img_path: Path of the background image; it is resized to
            the foreground's dimensions.

    Raises:
        ValueError: If the converted foreground has no alpha channel.
    """
    # The original bare ``DeprecationWarning(...)`` expression only built an
    # exception object and threw it away; warn explicitly instead.
    import warnings

    warnings.warn(
        "Created only for testing. Not in use",
        DeprecationWarning,
        stacklevel=2,
    )

    # Context managers make sure the underlying file handles are released.
    with Image.open(foreground_img_path) as foreground_img_pil:
        size = (foreground_img_pil.width, foreground_img_pil.height)
        input_image = ImageFormatConvertor.pil_to_cv2(foreground_img_pil)
    with Image.open(background_img_path) as background_image_pil:
        background_image = ImageFormatConvertor.pil_to_cv2(
            background_image_pil.resize(size)
        )

    # A 2-D (single-channel) array has no third axis at all, so check the
    # rank first instead of letting shape[2] raise IndexError.
    if input_image.ndim != 3 or input_image.shape[2] != 4:
        raise ValueError("Input image must have an alpha channel")

    alpha_channel = input_image[:, :, 3]

    # Resize the background array to match the input array's dimensions.
    background_image = cv2.resize(
        background_image, (input_image.shape[1], input_image.shape[0])
    )

    # Expand alpha to three channels and normalize it to the 0-1 range so it
    # can be used as a per-pixel blend weight.
    alpha_channel_3ch = cv2.cvtColor(alpha_channel, cv2.COLOR_GRAY2BGR)
    alpha_channel_3ch = alpha_channel_3ch / 255.0

    input_bgr = input_image[:, :, :3]
    background_bgr = background_image[:, :, :3]

    # result = alpha * foreground + (1 - alpha) * background
    foreground = cv2.multiply(alpha_channel_3ch, input_bgr.astype(float))
    background = cv2.multiply(1.0 - alpha_channel_3ch, background_bgr.astype(float))
    combined_image = cv2.add(foreground, background).astype(np.uint8)

    # Save and display the result.
    cv2.imwrite('path_to_save_combined_image.png', combined_image)
    cv2.imshow('Combined Image', combined_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
@classmethod
def replace_background_with_removebg(cls, foreground_img_pil: Image, background_image_pil: Image):
    """Swap the background of ``foreground_img_pil`` via ``cls.remove_bg``.

    Both images are converted to RGB, the background is resized to the
    foreground's width and height, and each is serialized with
    ``ImageFormatConvertor.pil_image_to_binary_data`` before being handed to
    ``cls.remove_bg``. Whatever ``cls.remove_bg`` returns is returned as-is.
    """
    person_rgb = foreground_img_pil.convert("RGB")
    target_size = (person_rgb.width, person_rgb.height)

    # The background must share the foreground's dimensions.
    scene_rgb = background_image_pil.convert("RGB").resize(target_size)

    person_payload = ImageFormatConvertor.pil_image_to_binary_data(person_rgb)
    scene_payload = ImageFormatConvertor.pil_image_to_binary_data(scene_rgb)
    return cls.remove_bg(person_payload, scene_payload)
@classmethod
def remove_bg(cls, foreground_binary: str, background_binary: str):
    """Send both images to the remove.bg API and return the composited image.

    The request is a multipart form upload to
    ``https://api.remove.bg/v1.0/removebg`` authenticated with the
    module-level ``REMOVE_BG_KEY``; it asks for a full-size result with a
    "3D" shadow and uses a 20 second timeout.

    Args:
        foreground_binary: Encoded foreground image, uploaded as
            ``foreground.png``. NOTE(review): annotated ``str`` but
            presumably raw bytes - confirm against
            ``ImageFormatConvertor.pil_image_to_binary_data``.
        background_binary: Encoded background image, uploaded as
            ``background.png``.

    Returns:
        The result of ``ImageFormatConvertor.binary_data_to_pil_image`` on the
        response body when the API answers 200. ``None`` on any failure (a
        non-200 status, or a transport error such as a timeout); the failure
        is logged.
    """
    # ref: https://www.remove.bg/api#api-reference
    url = "https://api.remove.bg/v1.0/removebg"

    # using form-data as passing binary data is not supported in application/json
    files = {
        "image_file": ('foreground.png', foreground_binary, 'image/png'),
        "bg_image_file": ('background.png', background_binary, 'image/png')
    }

    # get output image in same resolution as input
    payload = {
        "size": "full",
        "shadow_type": "3D"
    }
    headers = {
        "accept": "image/*",
        'X-Api-Key': REMOVE_BG_KEY
    }

    # Transport failures (DNS, connection reset, the timeout above, ...) are
    # reported the same way as HTTP failures: log and return None.
    try:
        remove_bg_request = requests.post(
            url, files=files, data=payload, headers=headers, timeout=20
        )
    except requests.RequestException as exc:
        logging.error("failed to use remove bg. Request error: %s", exc)
        return None

    if remove_bg_request.status_code != 200:
        logging.error(
            "failed to use remove bg. Status: %s. Resp: %s",
            remove_bg_request.status_code,
            remove_bg_request.content,
        )
        return None

    return ImageFormatConvertor.binary_data_to_pil_image(remove_bg_request.content)
@classmethod
def create_mask(cls, foreground_path: str, mask_path: str):
    """Create a binary mask from a background-removed image and save it.

    The image at ``foreground_path`` is converted to RGBA and its alpha
    channel is turned into a mask: pixels with alpha > 0 become white (255),
    fully transparent pixels become black (0).

    Args:
        foreground_path: Path of the foreground image with its background
            removed.
        mask_path: Path the single-channel mask image is written to.
    """
    # The context manager releases the file handle once the alpha channel has
    # been copied into a NumPy array.
    with Image.open(foreground_path) as foreground:
        # Convert to RGBA if not already; the alpha band is then the last one.
        alpha_channel = np.array(foreground.convert("RGBA").split()[-1])

    mask = np.where(alpha_channel > 0, 255, 0).astype(np.uint8)

    Image.fromarray(mask).save(mask_path)
@classmethod
def get_minimal_bounding_box(cls, foreground_pil: Image.Image):
    """Return the tightest axis-aligned box around the non-background pixels.

    With a four-channel image the alpha channel decides what counts as
    content; otherwise the image is converted to grayscale. In both cases a
    pixel counts as content when its value is greater than 1.

    Args:
        foreground_pil: Image to inspect; it is converted with
            ``ImageFormatConvertor.pil_to_cv2``.

    Returns:
        ``(x1, y1, x2, y2)`` in pixel coordinates, where ``(x1, y1)`` is the
        top-left corner and ``(x2, y2)`` is one past the bottom-right corner
        (``x + w``, ``y + h`` of the bounding rectangle). When no content
        pixel exists, the box covers the whole image:
        ``(0, 0, width, height)``.
    """
    foreground = ImageFormatConvertor.pil_to_cv2(foreground_pil)

    if foreground.shape[2] == 4:
        # Use the alpha channel (transparency) to find the content.
        alpha_channel = foreground[:, :, 3]
        _, binary_image = cv2.threshold(alpha_channel, 1, 255, cv2.THRESH_BINARY)
    else:
        # No alpha channel: fall back to grayscale intensity.
        gray_image = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY)
        _, binary_image = cv2.threshold(gray_image, 1, 255, cv2.THRESH_BINARY)

    # Find all non-zero points (non-background).
    non_zero_points = cv2.findNonZero(binary_image)

    if non_zero_points is None:
        # Nothing but background. The original code returned the unassigned
        # names ``w`` and ``h`` here and crashed; use the image size instead.
        height, width = foreground.shape[:2]
        return (0, 0, width, height)

    x, y, w, h = cv2.boundingRect(non_zero_points)
    return (x, y, x + w, y + h)
@classmethod
def color_transfer(cls, source_pil: Image.Image, target_pil: Image.Image) -> Image.Image:
    """Shift the colour statistics of ``source_pil`` towards ``target_pil``.

    Both images are converted to LAB. Each source channel is normalized by
    its own mean and standard deviation and rescaled to the mean and
    standard deviation of the matching target channel.

    NOT IN USE as output color was not good.

    Args:
        source_pil: Image to recolour.
        target_pil: Image whose per-channel LAB statistics are applied; it is
            converted to RGB and resized to the source's dimensions first.

    Returns:
        The recoloured image as returned by ``ImageFormatConvertor.cv2_to_pil``.
    """
    eps = 1e-8  # floor for the divisor, see below

    source = ImageFormatConvertor.pil_to_cv2(source_pil)

    # Resize the target to the source's dimensions.
    width, height = source_pil.width, source_pil.height
    target_pil = target_pil.convert("RGB")
    target_pil = target_pil.resize((width, height))
    target = ImageFormatConvertor.pil_to_cv2(target_pil)

    source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
    target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)

    # Per-channel mean and standard deviation of both images.
    source_mean, source_std = cv2.meanStdDev(source)
    target_mean, target_std = cv2.meanStdDev(target)

    # Reshape to (1, 1, 3) so the statistics broadcast over the image.
    source_mean = source_mean.reshape((1, 1, 3))
    source_std = source_std.reshape((1, 1, 3))
    target_mean = target_mean.reshape((1, 1, 3))
    target_std = target_std.reshape((1, 1, 3))

    # A perfectly uniform source channel has std == 0; dividing by it would
    # produce NaN/inf. Flooring the divisor makes such a channel collapse to
    # the target mean, because (source - source_mean) is zero there.
    source_std[source_std < eps] = eps

    result = (source - source_mean) * (target_std / source_std) + target_mean
    result = np.clip(result, 0, 255).astype(np.uint8)

    res = cv2.cvtColor(result, cv2.COLOR_LAB2BGR)
    return ImageFormatConvertor.cv2_to_pil(res)
@classmethod
def intensity_transfer(cls, source_pil: Image.Image, target_pil: Image.Image, eps: float = 1e-8) -> Image.Image:
    """Transfer the intensity distribution of the target image to the source image.

    Only the L channel of the LAB representation is changed: it is
    normalized by the source's L mean and standard deviation and rescaled to
    the target's. The source's A and B channels are kept.

    Args:
        source_pil: The source image (foreground) to be harmonized.
        target_pil: The target image (background) whose intensity
            distribution is to be matched; it is converted to RGB and resized
            to the source's dimensions first.
        eps: A small value used as a lower bound for the source standard
            deviation to avoid division by zero.

    Returns:
        The intensity-transferred source image, as returned by
        ``ImageFormatConvertor.cv2_to_pil``.
    """
    source = ImageFormatConvertor.pil_to_cv2(source_pil)

    # Resize the target to the source's dimensions.
    width, height = source_pil.width, source_pil.height
    target_pil = target_pil.convert("RGB")
    target_pil = target_pil.resize((width, height))
    target = ImageFormatConvertor.pil_to_cv2(target_pil)

    source_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)

    # Mean and standard deviation of the L channel (intensity) of both images.
    source_mean, source_std = cv2.meanStdDev(source_lab[:, :, 0])
    target_mean, target_std = cv2.meanStdDev(target_lab[:, :, 0])

    # Each statistic is a one-element array; use plain floats so the (H, W)
    # L plane keeps its shape instead of being broadcast to (1, H, W).
    source_mean = float(source_mean.reshape(-1)[0])
    source_std = float(source_std.reshape(-1)[0])
    target_mean = float(target_mean.reshape(-1)[0])
    target_std = float(target_std.reshape(-1)[0])

    # A uniform L channel has std == 0; bound the divisor from below.
    scale = target_std / max(source_std, eps)

    # Transfer the intensity (L channel).
    result_l = (source_lab[:, :, 0].astype(np.float64) - source_mean) * scale + target_mean
    result_l = np.clip(result_l, 0, 255).astype(np.uint8)

    # Combine the transferred L channel with the original A and B channels.
    result_lab = np.copy(source_lab)
    result_lab[:, :, 0] = result_l

    res = cv2.cvtColor(result_lab, cv2.COLOR_LAB2BGR)
    return ImageFormatConvertor.cv2_to_pil(res)
@classmethod
def match_color(cls, source_pil: Image, target_pil: Image):
    """Histogram-equalize the LAB channels of both images.

    The target is converted to RGB and resized to the source's dimensions.
    Each image is then taken to LAB, every one of its three channels is
    equalized independently with ``cv2.equalizeHist``, and the image is
    converted back. Neither image's histogram is fitted to the other's; both
    are equalized on their own.

    Returns:
        A ``(foreground, background)`` pair of images produced by
        ``ImageFormatConvertor.cv2_to_pil``.
    """
    def _equalize_lab_channels(bgr_image):
        # BGR -> LAB, equalize each channel in place, LAB -> BGR.
        lab_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB)
        for channel in range(3):
            lab_image[:, :, channel] = cv2.equalizeHist(lab_image[:, :, channel])
        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2BGR)

    foreground_bgr = ImageFormatConvertor.pil_to_cv2(source_pil)

    # Bring the background to the foreground's size before converting it.
    source_size = (source_pil.width, source_pil.height)
    resized_target = target_pil.convert("RGB").resize(source_size)
    background_bgr = ImageFormatConvertor.pil_to_cv2(resized_target)

    equalized_foreground = _equalize_lab_channels(foreground_bgr)
    equalized_background = _equalize_lab_channels(background_bgr)

    return (
        ImageFormatConvertor.cv2_to_pil(equalized_foreground),
        ImageFormatConvertor.cv2_to_pil(equalized_background),
    )