# Virtual-Try-On / src/background_processor.py
# parokshsaxena · commit 1dddd5f · "using shein sizes"
# (scraped from a web file viewer: raw / history / blame, 9.6 kB)
from PIL import Image, ImageEnhance
import cv2
import numpy as np
from preprocess.humanparsing.run_parsing import Parsing
# Module-level human-parsing model shared by every BackgroundProcessor method
# that builds a person mask. It is called as `parsing_model(img)` and returns a
# pair whose first element is a PIL image (the parse result).
# Constructed once at import time, so importing this module pays whatever cost
# Parsing.__init__ has.
# NOTE(review): the meaning of the argument 0 is not visible here (presumably a
# device / GPU index) — confirm against preprocess.humanparsing.run_parsing.
parsing_model = Parsing(0)
class BackgroundProcessor:
    """Replace the background behind a person in a photo.

    ``add_background`` and ``add_background_v3`` derive the person mask from the
    module-level ``parsing_model``. ``temp_v2`` and ``replace_background`` are
    file-path based OpenCV variants that read a mask / alpha channel from disk.
    Every method is a classmethod; the class holds no state.
    """

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _person_mask(cls, img: "Image.Image") -> "Image.Image":
        """Run the human parser on ``img`` and return an "L" mask of ``img.size``.

        Callers treat non-zero mask pixels as "person".
        """
        parsed_img, _ = parsing_model(img)
        # NOTE(review): resize() uses Pillow's default resampling for "L" images
        # (interpolating in recent Pillow versions), so if the parser output is a
        # different size than ``img`` the mask edges are smoothed. Kept as-is to
        # preserve existing output.
        return parsed_img.convert("L").resize(img.size)

    @classmethod
    def _imread(cls, path: str, flags: "int | None" = None) -> np.ndarray:
        """``cv2.imread`` that raises instead of silently returning ``None``.

        Args:
            path: image file path.
            flags: optional ``cv2.IMREAD_*`` flag; OpenCV's default when omitted.

        Raises:
            OSError: if OpenCV cannot read ``path`` (missing, unreadable or
                unsupported format).
        """
        img = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return img

    @classmethod
    def _composite_binary(cls, foreground: np.ndarray, background: np.ndarray,
                          mask: np.ndarray) -> np.ndarray:
        """Combine two same-size 3-channel uint8 images using a binary mask.

        ``mask`` is a single-channel uint8 image expected to hold only 0 and 255:
        255 selects ``foreground`` and 0 selects ``background``.
        """
        mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        mask_inv_3ch = cv2.cvtColor(cv2.bitwise_not(mask), cv2.COLOR_GRAY2BGR)
        person = cv2.bitwise_and(foreground, mask_3ch)
        backdrop = cv2.bitwise_and(background, mask_inv_3ch)
        # With a 0/255 mask the two halves are disjoint, so the add never saturates.
        return cv2.add(person, backdrop)

    # --------------------------------------------------------------- public API

    @classmethod
    def add_background(cls, human_img: "Image.Image",
                       background_img: "Image.Image") -> "Image.Image":
        """Paste the person from ``human_img`` onto ``background_img``.

        The person mask comes from ``parsing_model``. Pixels where the mask is
        non-zero keep the (RGB-converted) human image; all others come from the
        background, which is converted to RGB and resized to the human image size.

        Returns:
            A new RGB PIL image the size of ``human_img``.
        """
        human_img = human_img.convert("RGB")
        mask_np = np.array(cls._person_mask(human_img))
        background_np = np.array(background_img.convert("RGB").resize(human_img.size))
        human_np = np.array(human_img)
        # (H, W) mask broadcast over the channel axis as (H, W, 1).
        composite = np.where(mask_np[..., None] > 0, human_np, background_np)
        return Image.fromarray(composite.astype(np.uint8))

    @classmethod
    def temp_v2(cls, human_img_path, background_img_path, mask_img_path, *,
                display: bool = True):
        """Experimental: composite a person onto a background using a mask file.

        The person image is resized to 768x1024 (width x height); the mask file
        must already be that size. The composite is blurred with a 5x5 Gaussian
        and then given a contrast (x1.2) / brightness (+20) boost.

        Args:
            human_img_path: path of the image containing the person.
            background_img_path: path of the new background (resized to fit).
            mask_img_path: path of a grayscale mask; non-zero pixels mark the person.
            display: if True (the original behaviour), show the intermediate
                images in OpenCV windows and block until a key is pressed.

        Returns:
            The post-processed BGR uint8 image.

        Raises:
            OSError: if any image cannot be read.
            ValueError: if the mask is not the same size as the resized person image.
        """
        # cv2.resize, not ndarray.resize: the latter works in place and returns
        # None, which previously made this method crash on every call.
        foreground_img = cv2.resize(cls._imread(human_img_path), (768, 1024))
        background_img = cls._imread(background_img_path)
        mask_img = cls._imread(mask_img_path, cv2.IMREAD_GRAYSCALE)

        if foreground_img.shape[:2] != mask_img.shape[:2]:
            raise ValueError("Foreground image and mask must be the same size")

        background_img = cv2.resize(
            background_img, (foreground_img.shape[1], foreground_img.shape[0]))
        # Bitwise compositing needs a strict 0/255 mask; any other mask values
        # (e.g. a label map) would otherwise keep only some bits of each pixel.
        _, binary_mask = cv2.threshold(mask_img, 0, 255, cv2.THRESH_BINARY)
        combined_img = cls._composite_binary(foreground_img, background_img, binary_mask)

        # Blurs the whole composite (not only the seam) with a 5x5 kernel.
        blurred_combined_img = cv2.GaussianBlur(combined_img, (5, 5), 0)
        # alpha scales contrast, beta shifts brightness; result is saturated to uint8.
        post_processed_img = cv2.convertScaleAbs(blurred_combined_img, alpha=1.2, beta=20)

        if display:
            for title, img in (
                ("Foreground", foreground_img),
                ("Background", background_img),
                ("Mask", mask_img),
                ("Combined", combined_img),
                ("Post Processed", post_processed_img),
            ):
                cv2.imshow(title, img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return post_processed_img

    @classmethod
    def add_background_v3(cls, foreground_pil: "Image.Image",
                          background_pil: "Image.Image") -> "Image.Image":
        """Paste the person onto a new background using OpenCV bitwise masking.

        The ``parsing_model`` output is binarised (any non-zero value becomes
        person). The person and background are combined, and the whole result is
        softened with a 5x5 Gaussian blur.

        Returns:
            A new RGB PIL image the size of ``foreground_pil``.
        """
        foreground_pil = foreground_pil.convert("RGB")
        mask_pil = cls._person_mask(foreground_pil)
        background_pil = background_pil.convert("RGB").resize(foreground_pil.size)

        foreground_cv2 = cls.pil_to_cv2(foreground_pil)
        background_cv2 = cls.pil_to_cv2(background_pil)
        # An "L" image converts to a single-channel (H, W) uint8 array; no colour
        # conversion is needed.
        mask_cv2 = np.array(mask_pil)
        _, mask_cv2 = cv2.threshold(mask_cv2, 0, 255, cv2.THRESH_BINARY)

        combined_cv2 = cls._composite_binary(foreground_cv2, background_cv2, mask_cv2)
        # NOTE: this blurs the entire composite, not just the person/background
        # seam; true edge feathering would blur the mask and alpha-blend instead.
        blurred_combined_cv2 = cv2.GaussianBlur(combined_cv2, (5, 5), 0)
        return cls.cv2_to_pil(blurred_combined_cv2)

    @classmethod
    def replace_background(cls, foreground_img_path: str, background_img_path: str, *,
                           output_path: "str | None" = "path_to_save_combined_image.png",
                           display: bool = True):
        """Alpha-blend an image that has transparency over a new background.

        Args:
            foreground_img_path: path of an image with an alpha channel
                (4 channels as loaded by OpenCV). Assumes 8-bit channels.
            background_img_path: path of the new background (resized to fit).
            output_path: where to write the result; ``None`` skips writing.
                Defaults to the original placeholder path.
            display: if True (the original behaviour), show the result in an
                OpenCV window and block until a key is pressed.

        Returns:
            The blended BGR uint8 image. The original version returned ``None``.

        Raises:
            OSError: if either image cannot be read.
            ValueError: if the foreground image has no alpha channel.
        """
        # IMREAD_UNCHANGED keeps the alpha channel. OpenCV's default flag loads
        # 3-channel BGR, which made the check below fail for every input.
        input_image = cls._imread(foreground_img_path, cv2.IMREAD_UNCHANGED)
        background_image = cls._imread(background_img_path)

        if input_image.ndim != 3 or input_image.shape[2] != 4:
            raise ValueError("Input image must have an alpha channel")

        height, width = input_image.shape[:2]
        background_image = cv2.resize(background_image, (width, height))

        # (H, W, 1) alpha in [0, 1] broadcasts over the three colour channels.
        alpha = input_image[:, :, 3:4].astype(np.float64) / 255.0
        blended = alpha * input_image[:, :, :3] + (1.0 - alpha) * background_image
        combined_image = blended.astype(np.uint8)

        if output_path is not None:
            cv2.imwrite(output_path, combined_image)
        if display:
            cv2.imshow('Combined Image', combined_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return combined_image

    @classmethod
    def pil_to_cv2(cls, pil_image):
        """Convert a PIL image (or array) to an OpenCV-ordered NumPy array.

        RGB becomes BGR and RGBA becomes BGRA. Other layouts (e.g. grayscale) are
        returned unchanged. The result never shares memory with the input.
        """
        arr = np.array(pil_image)
        if arr.ndim == 3 and arr.shape[2] == 3:
            return arr[:, :, ::-1].copy()
        if arr.ndim == 3 and arr.shape[2] == 4:
            # Reorder colour channels only; alpha stays last. Fancy indexing copies.
            return arr[:, :, [2, 1, 0, 3]]
        return arr

    @classmethod
    def cv2_to_pil(cls, cv2_image):
        """Convert an OpenCV array (BGR / BGRA / grayscale) to a PIL image."""
        if cv2_image.ndim == 3 and cv2_image.shape[2] == 4:
            # Keep the alpha channel, mirroring pil_to_cv2.
            cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGRA2RGBA)
        elif cv2_image.ndim == 3:
            cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(cv2_image)