|
import os |
|
import mlflow |
|
import torch |
|
from torch.utils.data import IterableDataset, DataLoader |
|
from torchvision import transforms |
|
from datasets import load_dataset |
|
from skimage.color import rgb2lab |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
class ColorizeIterableDataset(IterableDataset):
    """Stream normalized (L, ab) tensor pairs for image colorization.

    Wraps an iterable of records that each expose a PIL-style image under
    the ``'image'`` key. Every image is converted to RGB if necessary,
    resized, converted to CIE Lab with ``rgb2lab`` and split into a
    lightness input and an ab-chrominance target.

    Args:
        dataset: Iterable of records supporting ``record['image']``.
        image_size: ``(height, width)`` every image is resized to. Use a
            pair (not a bare int) so all samples share one shape and can
            be stacked into batches. Defaults to ``(256, 256)``.

    Yields:
        tuple[torch.Tensor, torch.Tensor]: ``(l, ab)`` float32 tensors of
        shape ``(1, H, W)`` and ``(2, H, W)``. ``l`` is ``(L - 50) / 50``
        and ``ab`` is ``ab / 128``.

    NOTE(review): with ``DataLoader(num_workers > 0)`` every worker process
    iterates its own copy of ``self.dataset``. Whether samples end up split
    or duplicated across workers depends on the wrapped iterable — confirm
    for the dataset in use.
    """

    def __init__(self, dataset, image_size=(256, 256)):
        super().__init__()
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _lab_to_tensors(lab):
        """Split an (H, W, 3) Lab array into normalized (L, ab) tensors."""
        # Normalize in NumPy first so the arithmetic happens in the array's
        # own precision, then cast once to float32.
        l_chan = (lab[:, :, 0] - 50) / 50
        ab_chan = lab[:, :, 1:] / 128

        l_tensor = torch.as_tensor(l_chan, dtype=torch.float32).unsqueeze(0)
        # Channels-last (H, W, 2) -> channels-first (2, H, W).
        ab_tensor = torch.as_tensor(ab_chan, dtype=torch.float32).permute(2, 0, 1)
        return l_tensor, ab_tensor

    def _item_to_pair(self, item):
        """Turn one dataset record into an (L, ab) tensor pair."""
        img = item['image']
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = self.transform(img)

        # ToTensor produces (C, H, W); rgb2lab expects the channel axis last.
        lab = rgb2lab(img.permute(1, 2, 0).numpy())
        return self._lab_to_tensors(lab)

    def __iter__(self):
        for item in self.dataset:
            # Best-effort streaming: a record that cannot be processed is
            # reported and skipped instead of aborting the whole stream.
            try:
                pair = self._item_to_pair(item)
            except Exception as e:
                print(f"Error processing image: {str(e)}")
                continue
            # Yield outside the try so errors raised by the consumer are
            # never mistaken for image-processing failures.
            yield pair
|
|
|
def create_dataloaders(batch_size=32, num_workers=4):
    """Build a streaming training DataLoader over ImageNet-1k.

    Loads the ``imagenet-1k`` train split in streaming mode, wraps it in
    ``ColorizeIterableDataset`` and returns a ``DataLoader`` over it.

    Args:
        batch_size: Number of samples per batch. Defaults to 32.
        num_workers: Number of DataLoader worker processes. Defaults to 4;
            pass 0 to load in the calling process (handy for quick checks).

    Returns:
        The training ``DataLoader``, or ``None`` if any step failed. The
        error is printed rather than raised, so callers must check for
        ``None``.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")

        print("Creating custom dataset...")
        train_dataset = ColorizeIterableDataset(dataset)
        print("Custom dataset created.")

        print("Creating dataloader...")
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        print("Dataloader created.")

        return train_dataloader
    except Exception as e:
        # Callers rely on the None sentinel, so report and swallow here.
        print(f"Error in create_dataloaders: {str(e)}")
        return None
|
|
|
def _validate_colorization_batch(sample_batch, batch_size, image_size=(256, 256)):
    """Raise if *sample_batch* is not an (L, ab) pair with the expected shapes.

    Args:
        sample_batch: Batch produced by the DataLoader; expected to hold
            exactly two tensors, L first and ab second.
        batch_size: Expected size of the leading (batch) dimension.
        image_size: Expected ``(height, width)`` of every sample.

    Raises:
        Exception: If the batch does not hold two elements or either
            tensor has an unexpected shape.
    """
    if len(sample_batch) != 2:
        raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2")

    l_chan, ab_chan = sample_batch
    height, width = image_size
    expected_l = torch.Size([batch_size, 1, height, width])
    expected_ab = torch.Size([batch_size, 2, height, width])
    if l_chan.shape != expected_l or ab_chan.shape != expected_ab:
        raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}")


def test_data_ingestion():
    """Smoke-test the ingestion pipeline on the first batch only.

    Builds a DataLoader with a batch size of 4, pulls a single batch and
    checks that it is an (L, ab) pair shaped ``(4, 1, 256, 256)`` and
    ``(4, 2, 256, 256)``.

    Only the first batch is inspected: iterating the whole stream would
    pull the entire dataset, and a trailing partial batch would fail the
    fixed-size shape check even though ingestion works.

    Returns:
        bool: ``True`` if the check passed, ``False`` otherwise. Failures
        are printed, never raised.
    """
    print("Testing data ingestion...")
    batch_size = 4
    try:
        dataloader = create_dataloaders(batch_size=batch_size)
        if dataloader is None:
            raise Exception("Dataloader creation failed")

        # An empty loader must count as a failure, not a vacuous pass.
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            raise Exception("Dataloader produced no batches")

        _validate_colorization_batch(first_batch, batch_size)

        print("Data ingestion test passed.")
        return True
    except Exception as e:
        print(f"Data ingestion test failed: {str(e)}")
        return False
|
|
|
def _l_channel_preview(l_chan):
    """Convert a normalized (1, H, W) L tensor into an H x W image in [0, 1].

    ``mlflow.log_image`` accepts grayscale arrays shaped H x W and expects
    float values in [0, 1], so the leading channel axis is dropped and the
    [-1, 1] normalization is mapped back onto [0, 1].
    """
    return ((l_chan.squeeze(0) + 1) / 2).clamp(0, 1).numpy()


def _ab_channels_preview(ab_chan):
    """Convert a normalized (2, H, W) ab tensor into an H x W x 3 image in [0, 1].

    ``mlflow.log_image`` only accepts 1, 3 or 4 channels, so the two
    chrominance planes are placed in the first two channels and the third
    channel is filled with zeros.
    """
    ab = ((ab_chan + 1) / 2).clamp(0, 1)
    filler = torch.zeros_like(ab[:1])
    return torch.cat([ab, filler], dim=0).permute(1, 2, 0).contiguous().numpy()


def main():
    """Run the ingestion pipeline under an MLflow run, then the smoke test.

    Logs the run parameters, builds the training DataLoader and logs
    preview images of the first sample of the first batch. Any pipeline
    error is printed, recorded as the ``error`` param and marks the run as
    FAILED; nothing is re-raised.
    """
    batch_size = 32

    try:
        print("Starting data ingestion pipeline...")
        mlflow.start_run(run_name="data_ingestion")
        run_status = "FINISHED"

        try:
            print("Logging parameters...")
            mlflow.log_param("batch_size", batch_size)
            mlflow.log_param("dataset", "imagenet-1k")
            print("Parameters logged.")

            print("Creating dataloaders...")
            train_dataloader = create_dataloaders(batch_size=batch_size)
            if train_dataloader is None:
                raise Exception("Failed to create dataloader")
            print("Dataloaders created successfully.")

            print("Logging sample batch...")
            for l_chan, ab_chan in train_dataloader:
                mlflow.log_image(
                    _l_channel_preview(l_chan[0]),
                    "sample_input_l_channel.png",
                )
                mlflow.log_image(
                    _ab_channels_preview(ab_chan[0]),
                    "sample_target_ab_channels.png",
                )
                print("Sample batch logged.")
                # One batch is enough for a preview; do not drain the stream.
                break

            print("Data ingestion pipeline completed successfully.")

        except Exception as e:
            run_status = "FAILED"
            print(f"Error in data ingestion pipeline: {str(e)}")
            mlflow.log_param("error", str(e))

        finally:
            # Close the run with the real outcome instead of always FINISHED.
            mlflow.end_run(status=run_status)

    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")

    if test_data_ingestion():
        print("All tests passed.")
    else:
        print("Tests failed.")


if __name__ == "__main__":
    main()