|
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
from PIL import Image
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import json
|
|
import random
|
|
import cv2
|
|
|
|
|
|
def canny_processor(image, low_threshold=100, high_threshold=200):
    """Build a three-channel Canny edge image from ``image``.

    Args:
        image: Input converted with ``np.array`` before being handed to
            ``cv2.Canny`` (the dataset in this file passes a PIL image).
        low_threshold: First threshold forwarded to ``cv2.Canny``.
        high_threshold: Second threshold forwarded to ``cv2.Canny``.

    Returns:
        A PIL image whose three channels are copies of the same edge map.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    # Add a trailing channel axis, then copy the edge map into three channels.
    three_channel = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(three_channel)
|
|
|
|
|
|
def c_crop(image):
    """Center-crop ``image`` to a square whose side is its shorter dimension.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method, e.g. a PIL image.

    Returns:
        The result of ``image.crop`` for the centered square box.
    """
    width, height = image.size
    side = min(width, height)
    # Use integer edges and derive right/bottom from left/top so the box is
    # always exactly ``side`` x ``side``. Fractional edges (e.g. 106.5 and
    # 533.5 for a 640x427 image) are rounded one at a time by Pillow and can
    # end up one pixel off square.
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))
|
|
|
|
class CustomImageDataset(Dataset):
    """Dataset of (image, Canny hint, caption) triples read from one folder.

    Every file directly inside ``img_dir`` whose name ends in ``.jpg`` or
    ``.png`` is a sample. Its caption is read from the ``caption`` key of a
    JSON file that has the same path with the image extension replaced by
    ``.json``.
    """

    # File-name suffixes that mark a directory entry as an image.
    _IMAGE_SUFFIXES = ('.jpg', '.png')
    # Upper bound on how many samples are tried for one ``__getitem__`` call.
    _MAX_ATTEMPTS = 1000

    def __init__(self, img_dir, img_size=512):
        """Collect the sorted image paths found in ``img_dir``.

        Args:
            img_dir: Directory that is listed (non-recursively) for images.
            img_size: Side length, in pixels, each sample is resized to.
        """
        # Match on the suffix rather than a substring so that names such as
        # ``photo.jpg.bak`` are not mistaken for images.
        self.images = sorted(
            os.path.join(img_dir, name)
            for name in os.listdir(img_dir)
            if name.endswith(self._IMAGE_SUFFIXES)
        )
        self.img_size = img_size

    def __len__(self):
        """Return the number of image files found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, hint, prompt)`` for the sample at ``idx``.

        If the sample cannot be loaded, the error is printed and a randomly
        chosen sample is tried instead, so one bad file does not abort an
        epoch.

        Returns:
            A tuple of the image tensor, the Canny hint tensor (both
            channel-first, scaled to [-1, 1]) and the caption.

        Raises:
            RuntimeError: If ``_MAX_ATTEMPTS`` samples in a row fail to load.
        """
        last_error = None
        # A bounded loop replaces unbounded recursion: the stack no longer
        # grows with every bad file and a fully broken dataset fails cleanly.
        for _ in range(self._MAX_ATTEMPTS):
            try:
                return self._load_sample(idx)
            except Exception as e:
                # Deliberate best effort: report and substitute another sample.
                print(e)
                last_error = e
                idx = random.randint(0, len(self.images) - 1)
        raise RuntimeError(
            f'failed to load a sample after {self._MAX_ATTEMPTS} attempts'
        ) from last_error

    def _load_sample(self, idx):
        """Load, crop, resize and caption the sample at ``idx``."""
        path = self.images[idx]
        # Close the file handle promptly. Force three channels, because
        # grayscale/palette images yield 2-D arrays that ``permute(2, 0, 1)``
        # rejects and RGBA/CMYK images yield 4 channels that cannot be
        # batched with RGB samples.
        with Image.open(path) as source:
            img = source.convert('RGB')
        img = c_crop(img).resize((self.img_size, self.img_size))
        hint = canny_processor(img)
        img = self._to_chw_tensor(img)
        hint = self._to_chw_tensor(hint)
        # Binary mode lets ``json.load`` detect the UTF-8/16/32 encoding
        # itself instead of relying on the platform's locale encoding.
        with open(self._caption_path(path), 'rb') as caption_file:
            prompt = json.load(caption_file)['caption']
        return img, hint, prompt

    @staticmethod
    def _caption_path(image_path):
        """Return ``image_path`` with its final extension swapped for ``.json``."""
        # ``splitext`` strips only the last extension, so dots elsewhere in
        # the path (``./data``, ``v1.0/``, ``a.b.png``) are left alone.
        return os.path.splitext(image_path)[0] + '.json'

    @staticmethod
    def _to_chw_tensor(pixels):
        """Map pixel values 0..255 onto [-1, 1] and move channels first."""
        scaled = (np.array(pixels) / 127.5) - 1
        return torch.from_numpy(scaled).permute(2, 0, 1)
|
|
|
|
|
|
def loader(train_batch_size, num_workers, **args):
    """Create a ``DataLoader`` over a ``CustomImageDataset``.

    Args:
        train_batch_size: Passed to ``DataLoader`` as ``batch_size``.
        num_workers: Passed to ``DataLoader`` as ``num_workers``.
        **args: Keyword arguments forwarded unchanged to
            ``CustomImageDataset``.

    Returns:
        The constructed ``DataLoader``.
    """
    return DataLoader(
        CustomImageDataset(**args),
        batch_size=train_batch_size,
        num_workers=num_workers,
    )
|
|
|