File size: 5,016 Bytes
02f3f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import mlflow
import torch
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
from skimage.color import rgb2lab
from PIL import Image
import numpy as np


class ColorizeIterableDataset(IterableDataset):
    """Stream (L, ab) LAB-space tensor pairs for colorization training.

    Each yielded sample is a tuple:
      - L channel:  float32 tensor of shape (1, 256, 256), scaled to [-1, 1]
      - ab channels: float32 tensor of shape (2, 256, 256), scaled to ~[-1, 1]
    """

    def __init__(self, dataset):
        # `dataset` is an iterable of dicts with an 'image' key holding a
        # PIL image (e.g. an HF streaming dataset) — see create_dataloaders.
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def __iter__(self):
        # Shard the stream across DataLoader workers. Without this, every
        # worker process iterates the entire underlying dataset, so a
        # DataLoader with num_workers > 1 yields each sample once per
        # worker (duplicated batches).
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        for idx, item in enumerate(self.dataset):
            if idx % num_workers != worker_id:
                continue
            try:
                img = item['image']
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = self.transform(img)

                # Convert to LAB color space (returns an HWC float64 array).
                lab = rgb2lab(img.permute(1, 2, 0).numpy())

                # Normalize L channel (nominal range [0, 100]) to [-1, 1].
                l_chan = (lab[:, :, 0] - 50) / 50

                # Normalize ab channels (nominal range ~[-128, 127]) to ~[-1, 1].
                ab_chan = lab[:, :, 1:] / 128

                yield (torch.from_numpy(l_chan).float().unsqueeze(0),
                       torch.from_numpy(ab_chan).float().permute(2, 0, 1))
            except Exception as e:
                # Best-effort streaming: skip unreadable/corrupt images
                # rather than aborting the whole epoch.
                print(f"Error processing image: {str(e)}")
                continue

def create_dataloaders(batch_size=32, num_workers=4):
    """Build a streaming DataLoader of ImageNet-1k colorization pairs.

    Args:
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes (previously hard-coded
            to 4). NOTE(review): with an IterableDataset each worker
            iterates the underlying stream independently; the dataset
            must shard per-worker or batches will be duplicated.

    Returns:
        A DataLoader yielding (L, ab) tensor batches, or None if any
        step of the setup fails (errors are printed, not raised).
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        # Streaming avoids downloading the full ImageNet archive up front.
        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")
        
        print("Creating custom dataset...")
        # Wrap the raw stream to produce LAB-space training pairs.
        train_dataset = ColorizeIterableDataset(dataset)
        print("Custom dataset created.")
        
        print("Creating dataloader...")
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
        print("Dataloader created.")
        
        return train_dataloader
    except Exception as e:
        # Best-effort: callers check for None rather than catching.
        print(f"Error in create_dataloaders: {str(e)}")
        return None

def test_data_ingestion():
    """Smoke-test the pipeline: fetch one batch and check tensor shapes.

    Returns:
        True if the first batch has the expected format and shapes
        (L: [4, 1, 256, 256], ab: [4, 2, 256, 256]), False otherwise.
    """
    print("Testing data ingestion...")
    try:
        dataloader = create_dataloaders(batch_size=4)
        if dataloader is None:
            raise Exception("Dataloader creation failed")
        
        # Inspect only the first batch.
        for sample_batch in dataloader:
            if len(sample_batch) != 2:
                raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2")
            
            l_chan, ab_chan = sample_batch
            if l_chan.shape != torch.Size([4, 1, 256, 256]) or ab_chan.shape != torch.Size([4, 2, 256, 256]):
                raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}")
            
            print("Data ingestion test passed.")
            return True
        
        # Bug fix: an empty stream previously fell off the loop and
        # returned None; fail explicitly instead.
        raise Exception("Dataloader yielded no batches")
    except Exception as e:
        print(f"Data ingestion test failed: {str(e)}")
        return False

if __name__ == "__main__":
    try:
        print("Starting data ingestion pipeline...")
        mlflow.start_run(run_name="data_ingestion")
        
        try:
            # Log run parameters.
            print("Logging parameters...")
            mlflow.log_param("batch_size", 32)
            mlflow.log_param("dataset", "imagenet-1k")
            print("Parameters logged.")
            
            # Create dataloaders.
            print("Creating dataloaders...")
            train_dataloader = create_dataloaders(batch_size=32)
            if train_dataloader is None:
                raise Exception("Failed to create dataloader")
            print("Dataloaders created successfully.")
            
            # Log one sample input/target pair as image artifacts.
            print("Logging sample batch...")
            for sample_batch in train_dataloader:
                l_chan, ab_chan = sample_batch
                
                # mlflow.log_image expects an HxW (or HxWxC) array with
                # float values in [0, 1]; the tensors here are
                # channel-first and scaled to [-1, 1], so squeeze the
                # channel axis and rescale before logging.
                sample_input = ((l_chan[0, 0].numpy() + 1) / 2).clip(0.0, 1.0)
                mlflow.log_image(sample_input, "sample_input_l_channel.png")
                
                # A 2-channel ab array is not a valid image; log each
                # channel separately as a grayscale image instead.
                sample_target = ((ab_chan[0].numpy() + 1) / 2).clip(0.0, 1.0)
                mlflow.log_image(sample_target[0], "sample_target_a_channel.png")
                mlflow.log_image(sample_target[1], "sample_target_b_channel.png")
                
                print("Sample batch logged.")
                break  # We only need one batch for logging
            
            print("Data ingestion pipeline completed successfully.")
        
        except Exception as e:
            print(f"Error in data ingestion pipeline: {str(e)}")
            mlflow.log_param("error", str(e))
        
        finally:
            # Always close the run, even on failure.
            mlflow.end_run()
    
    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")

    if test_data_ingestion():
        print("All tests passed.")
    else:
        print("Tests failed.")