import os
import json
import copy
import random
import pickle
import re

import datasets
from datasets import load_dataset, ClassLabel, concatenate_datasets
import torch
import numpy as np
from PIL import Image
# import torchvision.transforms as T
from torchvision import transforms

from OmniGen import OmniGenProcessor
from OmniGen.processor import OmniGenCollator


class DatasetFromJson(torch.utils.data.Dataset):
    """Training dataset backed by a JSON file of multimodal examples.

    Each record is expected to contain 'instruction' (str), 'input_images'
    (list of image paths or None), and 'output_image' (image path).
    Items are returned as (mllm_input, output_image_tensor) pairs.
    """

    def __init__(
        self,
        json_file: str,
        image_path: str,
        processer: OmniGenProcessor,
        image_transform,
        max_input_length_limit: int = 18000,
        condition_dropout_prob: float = 0.1,
        keep_raw_resolution: bool = True,
    ):
        """
        Args:
            json_file: path to the JSON data file (loaded via HF `datasets`).
            image_path: optional root directory prepended to image file names.
            processer: OmniGen processor used to tokenize multimodal prompts.
                (Name kept as-is — 'processer' — for caller compatibility.)
            image_transform: callable applied to each loaded PIL image.
            max_input_length_limit: examples with more tokens are rejected.
            condition_dropout_prob: probability of dropping the condition
                (instruction and input images) for classifier-free guidance.
            keep_raw_resolution: whether images keep their native resolution.
        """
        self.image_transform = image_transform
        self.processer = processer
        self.condition_dropout_prob = condition_dropout_prob
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution
        self.data = load_dataset('json', data_files=json_file)['train']
        self.image_path = image_path

    def process_image(self, image_file):
        """Load `image_file` (optionally relative to self.image_path) as RGB
        and apply the configured transform."""
        if self.image_path is not None:
            image_file = os.path.join(self.image_path, image_file)
        image = Image.open(image_file).convert('RGB')
        return self.image_transform(image)

    def get_example(self, index):
        """Build one (mllm_input, output_image) training example."""
        example = self.data[index]
        instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
        # Condition dropout for classifier-free guidance: with probability
        # condition_dropout_prob, drop both the text and the input images.
        if random.random() < self.condition_dropout_prob:
            instruction = ''
            input_images = None

        if input_images is not None:
            input_images = [self.process_image(x) for x in input_images]
        mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)

        output_image = self.process_image(output_image)
        return (mllm_input, output_image)

    def __getitem__(self, index):
        # BUG FIX: the original returned self.get_example(index) immediately,
        # which made the retry loop below dead code and silently skipped the
        # max_input_length_limit check. Now bad/oversized examples are retried
        # up to 8 times with a random replacement index.
        for _ in range(8):
            try:
                mllm_input, output_image = self.get_example(index)
                if len(mllm_input['input_ids']) > self.max_input_length_limit:
                    raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
                return mllm_input, output_image
            except Exception as e:
                print("error when loading data: ", e)
                print(self.data[index])
                # Fall back to a random example and try again.
                index = random.randint(0, len(self.data) - 1)
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.data)


class TrainDataCollator(OmniGenCollator):
    """Batch collator that pads multimodal inputs and stacks target images."""

    def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size
        self.keep_raw_resolution = keep_raw_resolution

    def __call__(self, features):
        """Collate a list of (mllm_input, output_image) pairs into a batch dict.

        Returns a dict with padded input ids / attention masks / position ids,
        optional input pixel values and sizes, padding-image info, and the
        target output images.
        """
        mllm_inputs = [f[0] for f in features]
        output_images = [f[1].unsqueeze(0) for f in features]
        target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]

        # process_mllm_input is provided by OmniGenCollator (defined elsewhere).
        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)

        if not self.keep_raw_resolution:
            # BUG FIX: the original referenced undefined names `output_image`
            # and `pixel_values` (NameError on this branch) and discarded the
            # concatenated output. When all images share one resolution they
            # can be stacked into a single batch tensor.
            output_images = torch.cat(output_images, dim=0)
            if len(all_pixel_values) > 0:
                all_pixel_values = torch.cat(all_pixel_values, dim=0)
            else:
                all_pixel_values = None

        data = {"input_ids": all_padded_input_ids,
                "attention_mask": all_attention_mask,
                "position_ids": all_position_ids,
                "input_pixel_values": all_pixel_values,
                "input_image_sizes": all_image_sizes,
                "padding_images": all_padding_images,
                "output_images": output_images,
                }
        return data