Rohil Bansal committed on
Commit
02f3f24
1 Parent(s): 8e66116

huggingface spaces commit.

Browse files
Files changed (14) hide show
  1. .gitattributes +1 -35
  2. .gitignore +7 -0
  3. Dockerfile +10 -0
  4. README.md +11 -12
  5. app-mlflow.py +72 -0
  6. app.py +123 -0
  7. colorizer_pipeline.py +379 -0
  8. data_ingestion.py +138 -0
  9. inference.py +117 -0
  10. instructions.txt +13 -0
  11. model.py +144 -0
  12. requirements.txt +7 -0
  13. run_colorizer.py +118 -0
  14. train.py +235 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ .env
3
+ __pycache__
4
+ *.pth.tar
5
+ results/
6
+ mlruns/
7
+ # checkpoints/
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["python", "data_ingestion.py"]
README.md CHANGED
@@ -1,12 +1,11 @@
1
- ---
2
- title: Pix2pixcolorizer
3
- emoji: 📊
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.42.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # Image Colorizer
2
+
3
+ This is a Gradio app that colorizes grayscale images using a trained GAN model. Upload a grayscale image, and the app will return a colorized version of the image.
4
+
5
+ ## How to use
6
+
7
+ 1. Upload a grayscale image using the provided interface.
8
+ 2. Wait for the model to process the image.
9
+ 3. View the colorized result!
10
+
11
+ This app uses a GAN model trained on the ImageNet dataset to add color to grayscale images.
 
app-mlflow.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import mlflow
4
+ import numpy as np
5
+ from PIL import Image
6
+ from skimage.color import rgb2lab, lab2rgb
7
+ from torchvision import transforms
8
+
9
+ from model import Generator
10
+
11
+ EXPERIMENT_NAME = "Colorizer_Experiment"
12
+ RUN_ID = "your_run_id_here" # Replace with your actual run ID
13
+
14
def setup_mlflow():
    """Return the MLflow experiment ID for ``EXPERIMENT_NAME``.

    Looks the experiment up by name and creates it when the lookup
    returns ``None``.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
21
+
22
def load_model(run_id, device):
    """Load the generator logged as ``generator_model`` in an MLflow run.

    Args:
        run_id: MLflow run ID whose ``generator_model`` artifact is loaded.
        device: passed to the loader as ``map_location``.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # This module only performs inference: put Dropout/BatchNorm layers into
    # evaluation mode so results are deterministic (app.py does the same with
    # generator.eval()).
    model.eval()
    return model
27
+
28
def preprocess_image(image):
    """Convert an image array into the normalised L-channel tensor for the model.

    The array is converted to RGB, resized to 256x256, turned into LAB, and
    the L channel is rescaled with ``(L - 50) / 50``.

    Returns:
        A float tensor with two leading singleton dimensions added in front
        of the L channel.
    """
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    rgb_chw = pipeline(Image.fromarray(image).convert("RGB"))
    # rgb2lab expects channels last, so move C to the end first.
    lightness = rgb2lab(rgb_chw.permute(1, 2, 0).numpy())[:, :, 0]
    lightness = (lightness - 50) / 50
    return torch.from_numpy(lightness).unsqueeze(0).unsqueeze(0).float()
40
+
41
def postprocess_output(L, ab):
    """Combine a normalised L tensor and predicted ab tensor into an RGB image.

    Args:
        L: lightness tensor; after ``squeeze()`` it must be ``(H, W)``.
        ab: chroma tensor; after ``squeeze()`` it must be ``(2, H, W)``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    L = L.squeeze().cpu().numpy()
    ab = ab.squeeze().cpu().numpy()
    # Undo the normalisation applied in preprocessing / training.
    L = (L + 1.) * 50.
    ab = ab * 128.
    # ab is channel-first (2, H, W) while L is (H, W): stack them channel-first
    # and only then move channels last. Concatenating (H, W, 1) with (2, H, W)
    # on axis 2, as before, always failed with a shape-mismatch ValueError.
    Lab = np.concatenate([L[np.newaxis, ...], ab], axis=0).transpose(1, 2, 0)
    rgb_img = lab2rgb(Lab)
    return (rgb_img * 255).astype(np.uint8)
49
+
50
def colorize_image(image, model, device):
    """Run ``model`` on the preprocessed ``image`` and return the colorized result.

    The image goes through ``preprocess_image``, the model is called without
    gradient tracking, and the output is decoded by ``postprocess_output``.
    """
    lightness = preprocess_image(image).to(device)
    with torch.no_grad():
        predicted_ab = model(lightness)
    return postprocess_output(lightness, predicted_ab)
56
+
57
def setup_gradio_app(run_id, device):
    """Build the Gradio interface backed by the model from ``run_id``.

    The model is loaded once here and captured by the callback, so it is not
    reloaded for every request.

    Returns:
        The configured ``gr.Interface`` (not launched).
    """
    model = load_model(run_id, device)

    def gradio_colorize(input_image):
        # Gradio hands over an array; return a PIL image for display.
        return Image.fromarray(colorize_image(input_image, model, device))

    return gr.Interface(
        fn=gradio_colorize,
        inputs=gr.Image(label="Upload a grayscale image"),
        outputs=gr.Image(label="Colorized Image"),
        title="Image Colorizer",
        description="Upload a grayscale image and get a colorized version!",
    )
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+ from skimage.color import rgb2lab, lab2rgb
7
+ import os
8
+ from torch import nn
9
+
10
+ # Define the model architecture (same as in the training script)
11
class UNetBlock(nn.Module):
    """One U-Net stage: a 4x4 stride-2 (transposed) convolution followed by
    optional BatchNorm, optional Dropout(0.5) and a ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        down: ``True`` uses ``nn.Conv2d``; ``False`` uses ``nn.ConvTranspose2d``.
        bn: add ``nn.BatchNorm2d`` after the convolution.
        dropout: add ``nn.Dropout(0.5)`` after the normalisation.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever optional layers were configured, in order.
        for layer in (self.bn, self.dropout):
            if layer is not None:
                out = layer(out)
        return torch.relu(out)
27
+
28
class Generator(nn.Module):
    """U-Net generator: 1 input channel in, 2 tanh-activated channels out.

    Eight down blocks (``down1``..``down8``) are mirrored by seven up blocks
    (``up1``..``up7``) and a final transposed convolution (``up8``). Every up
    stage after the first receives the previous output concatenated with the
    matching down-stage output along the channel dimension.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, use_batchnorm) for down1..down8
        encoder_specs = [
            (1, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        for idx, (c_in, c_out, use_bn) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{idx}", UNetBlock(c_in, c_out, bn=use_bn))

        # (in_channels, out_channels, use_dropout) for up1..up7
        decoder_specs = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for idx, (c_in, c_out, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{idx}", UNetBlock(c_in, c_out, down=False, dropout=use_dropout))
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        skips = []
        feat = x
        for idx in range(1, 9):
            feat = getattr(self, f"down{idx}")(feat)
            skips.append(feat)
        skips.pop()  # the deepest output feeds up1 directly; it is not a skip

        feat = self.up1(feat)
        for idx in range(2, 8):
            # Pair each up stage with the most recent unused down output.
            feat = getattr(self, f"up{idx}")(torch.cat([feat, skips.pop()], 1))
        return torch.tanh(self.up8(torch.cat([feat, skips.pop()], 1)))
67
+
68
+ # Load the checkpoint
69
def load_checkpoint(filename, generator, map_location):
    """Load generator weights from ``filename`` if that file exists.

    Reads the ``generator_state_dict`` and ``epoch`` entries of the
    checkpoint. When the file is missing, only a message is printed and the
    generator is left untouched.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return
    print(f"Loading checkpoint '{filename}'")
    state = torch.load(filename, map_location=map_location)
    generator.load_state_dict(state['generator_state_dict'])
    print(f"Loaded checkpoint '{filename}' (epoch {state['epoch']})")
77
+
78
+ # Initialize the model
79
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"

# Build the generator, restore its weights (if the checkpoint file exists)
# and switch it to evaluation mode for inference.
generator = Generator().to(device)
load_checkpoint(checkpoint_path, generator, map_location=device)
generator.eval()

# Resize to 256x256, reduce to a single grayscale channel, convert to tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])
91
+
92
+ # Define the inference function
93
def colorize_image(input_image):
    """Colorize a PIL image with the module-level ``generator``.

    The generator input is built the same way the training pipeline builds
    it: the image is converted to LAB and its L channel is rescaled with
    ``(L - 50) / 50``. The previous version fed the network a ``[0, 1]``
    grayscale tensor and then decoded it as if it were normalised L, which
    does not match what the model was trained on.

    Args:
        input_image: PIL image supplied by Gradio, or ``None`` when nothing
            was uploaded.

    Returns:
        The colorized PIL image at the original size, or ``None`` when there
        is no input or processing fails.
    """
    if input_image is None:
        # Nothing was uploaded; there is nothing to colorize.
        return None
    try:
        original_size = input_image.size
        resized = input_image.convert("RGB").resize((256, 256), Image.BILINEAR)
        rgb = np.asarray(resized, dtype=np.float32) / 255.0
        # L channel normalised exactly like the training data.
        l_norm = (rgb2lab(rgb)[:, :, 0] - 50.) / 50.
        l_tensor = torch.from_numpy(l_norm).float().unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            output = generator(l_tensor)
        output = output.squeeze(0).cpu().numpy()  # (2, H, W)
        L = l_tensor.squeeze(0).cpu().numpy()     # (1, H, W)
        # Undo the normalisation before converting back to RGB.
        L = (L + 1.) * 50.
        ab = output * 128.
        Lab = np.concatenate([L, ab], axis=0).transpose(1, 2, 0)
        rgb_image = lab2rgb(Lab)
        rgb_image = Image.fromarray((rgb_image * 255).astype(np.uint8))
        rgb_image = rgb_image.resize(original_size, Image.LANCZOS)
        return rgb_image
    except Exception as e:
        print(f"Error in colorize_image: {str(e)}")
        return None
111
+
112
+ # Create the Gradio interface
113
# Gradio UI: one PIL image in, one PIL image out.
_image_in = gr.Image(type="pil")
_image_out = gr.Image(type="pil")
iface = gr.Interface(
    fn=colorize_image,
    inputs=_image_in,
    outputs=_image_out,
    title="Image Colorizer",
    description="Upload a grayscale image to colorize it."
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
colorizer_pipeline.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import IterableDataset, DataLoader
6
+ from torchvision import transforms
7
+ from torchvision.utils import make_grid
8
+ import mlflow
9
+ import matplotlib.pyplot as plt
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ from skimage.color import rgb2lab, lab2rgb
13
+ from datasets import load_dataset
14
+ from PIL import Image
15
+ from itertools import islice
16
+ import traceback
17
+
18
+ # MLflow setup
19
+ EXPERIMENT_NAME = "Colorizer_Experiment"
20
+
21
def setup_mlflow():
    """Return the MLflow experiment ID for ``EXPERIMENT_NAME``.

    Looks the experiment up by name and creates it when the lookup
    returns ``None``.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
28
+
29
+ # Data ingestion
30
class ColorizeIterableDataset(IterableDataset):
    """Stream ``(L, ab)`` tensor pairs derived from the ``'image'`` field of
    each item in ``dataset``.

    Items that fail processing, or whose shape after the transform is not
    ``(3, 256, 256)``, are reported with a printed message and skipped.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def _to_lab_pair(self, record):
        """Return the normalised ``(L, ab)`` pair for one record, or ``None``
        when a shape check fails."""
        picture = record['image']
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        if picture.shape != (3, 256, 256):
            print(f"Unexpected image shape after transform: {picture.shape}")
            return None

        lab = rgb2lab(picture.permute(1, 2, 0).numpy())
        if lab.shape != (256, 256, 3):
            print(f"Unexpected lab shape: {lab.shape}")
            return None

        # L is rescaled with (L - 50) / 50, ab with ab / 128.
        lightness = (lab[:, :, 0] - 50) / 50
        chroma = lab[:, :, 1:] / 128
        return torch.Tensor(lightness).unsqueeze(0), torch.Tensor(chroma).permute(2, 0, 1)

    def __iter__(self):
        for record in self.dataset:
            try:
                pair = self._to_lab_pair(record)
                if pair is not None:
                    yield pair
            except Exception as e:
                # Best effort: report the bad sample and move on.
                print(f"Error processing image: {str(e)}")
65
+
66
def create_dataloaders(batch_size=32):
    """Build the training ``DataLoader`` over the streamed ``imagenet-1k``
    train split.

    Images are resized to 256x256 and converted to tensors before the
    dataset wrapper turns them into ``(L, ab)`` pairs.

    Returns:
        The ``DataLoader`` (4 workers), or ``None`` if any step raised.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        stream = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")

        print("Creating custom dataset...")
        resize_to_tensor = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        colorize_ds = ColorizeIterableDataset(stream, transform=resize_to_tensor)
        print("Custom dataset created.")

        print("Creating dataloader...")
        loader = DataLoader(colorize_ds, batch_size=batch_size, num_workers=4)
        print("Dataloader created.")
    except Exception as e:
        print(f"Error in create_dataloaders: {str(e)}")
        return None
    return loader
88
+
89
+ # Model definition
90
class UNetBlock(nn.Module):
    """One U-Net stage: a 4x4 stride-2 (transposed) convolution followed by
    optional BatchNorm, optional Dropout(0.5) and a ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        down: ``True`` uses ``nn.Conv2d``; ``False`` uses ``nn.ConvTranspose2d``.
        bn: add ``nn.BatchNorm2d`` after the convolution.
        dropout: add ``nn.Dropout(0.5)`` after the normalisation.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever optional layers were configured, in order.
        for layer in (self.bn, self.dropout):
            if layer is not None:
                out = layer(out)
        return torch.relu(out)
106
+
107
class Generator(nn.Module):
    """U-Net generator: 1 input channel in, 2 tanh-activated channels out.

    Eight down blocks (``down1``..``down8``) are mirrored by seven up blocks
    (``up1``..``up7``) and a final transposed convolution (``up8``). Every up
    stage after the first receives the previous output concatenated with the
    matching down-stage output along the channel dimension.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, use_batchnorm) for down1..down8
        encoder_specs = [
            (1, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        for idx, (c_in, c_out, use_bn) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{idx}", UNetBlock(c_in, c_out, bn=use_bn))

        # (in_channels, out_channels, use_dropout) for up1..up7
        decoder_specs = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for idx, (c_in, c_out, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{idx}", UNetBlock(c_in, c_out, down=False, dropout=use_dropout))
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        skips = []
        feat = x
        for idx in range(1, 9):
            feat = getattr(self, f"down{idx}")(feat)
            skips.append(feat)
        skips.pop()  # the deepest output feeds up1 directly; it is not a skip

        feat = self.up1(feat)
        for idx in range(2, 8):
            # Pair each up stage with the most recent unused down output.
            feat = getattr(self, f"up{idx}")(torch.cat([feat, skips.pop()], 1))
        return torch.tanh(self.up8(torch.cat([feat, skips.pop()], 1)))
146
+
147
class Discriminator(nn.Module):
    """Convolutional discriminator over 3-channel inputs producing a
    1-channel map of raw (un-activated) scores.

    Three stride-2 and two stride-1 4x4 convolutions, with BatchNorm on the
    middle three and LeakyReLU(0.2) after all but the last.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (in_channels, out_channels, stride) of the normalised stages
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
167
+
168
def init_weights(model):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    from N(1, 0.02) and zero bias. Everything else is left untouched.
    """
    kind = type(model).__name__
    if 'Conv' in kind:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in kind:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)
175
+
176
+ # Training utilities
177
def lab_to_rgb(L, ab):
    """Convert batched normalised L and ab tensors to a stacked RGB array.

    L is rescaled with ``(L + 1) * 50`` and ab with ``ab * 128``. The two are
    joined on dim 1, moved to channels-last, and each image is passed
    through ``lab2rgb``.

    Returns:
        NumPy array stacking the per-image RGB results along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 128.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
186
+
187
def visualize_results(epoch, generator, train_loader, device):
    """Save and log a real-vs-generated comparison grid for one batch.

    Takes the first batch from ``train_loader``, runs the generator on it in
    evaluation mode, and places the real and generated RGB images side by
    side. The grid is written to ``results/epoch_{epoch}.png`` and logged as
    an MLflow artifact. Errors are printed with a traceback and never
    propagated; the generator is always returned to training mode.

    Args:
        epoch: epoch index used in the plot title and file name.
        generator: model mapping L tensors to ab tensors.
        train_loader: iterable yielding ``(inputs, real_AB)`` batches.
        device: device the batch is moved to before the forward pass.
    """
    generator.eval()
    try:
        # Make the output directory here too so the function works standalone.
        os.makedirs("results", exist_ok=True)
        with torch.no_grad():
            for inputs, real_AB in train_loader:
                print(f"Input shape: {inputs.shape}, real_AB shape: {real_AB.shape}")

                # Ensure inputs have the correct shape (B, 1, H, W)
                if inputs.shape[1] != 1:
                    inputs = inputs.unsqueeze(1)

                inputs, real_AB = inputs.to(device), real_AB.to(device)
                fake_AB = generator(inputs)

                print(f"fake_AB shape: {fake_AB.shape}")

                # Ensure fake_AB and real_AB have the correct shape (B, 2, H, W)
                if fake_AB.shape[1] != 2:
                    fake_AB = fake_AB.view(fake_AB.shape[0], 2, fake_AB.shape[2], fake_AB.shape[3])
                if real_AB.shape[1] != 2:
                    real_AB = real_AB.view(real_AB.shape[0], 2, real_AB.shape[2], real_AB.shape[3])

                fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
                real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())

                print(f"fake_rgb shape: {fake_rgb.shape}, real_rgb shape: {real_rgb.shape}")

                # Arrays are (B, H, W, 3): axis 2 puts real and fake side by side.
                concatenated = np.concatenate([real_rgb, fake_rgb], axis=2)
                print(f"Concatenated shape: {concatenated.shape}")

                img_grid = make_grid(torch.from_numpy(concatenated).permute(0, 3, 1, 2), normalize=True, nrow=4)

                out_path = f'results/epoch_{epoch}.png'
                fig = plt.figure(figsize=(15, 15))
                try:
                    plt.imshow(img_grid.permute(1, 2, 0).cpu())
                    plt.axis('off')
                    plt.title(f'Epoch {epoch}')
                    plt.savefig(out_path)
                    mlflow.log_artifact(out_path)
                finally:
                    # Always release the figure; otherwise a failed save or
                    # upload leaks one open figure per epoch.
                    plt.close(fig)
                break  # one batch is enough for a preview
    except Exception as e:
        print(f"Error in visualize_results: {str(e)}")
        traceback.print_exc()
    finally:
        generator.train()
231
+
232
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` with ``torch.save`` and log that file
    as an MLflow artifact."""
    torch.save(obj=state, f=filename)
    mlflow.log_artifact(local_path=filename)
235
+
236
def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD, map_location="cpu"):
    """Restore models and optimizers from ``filename`` when it exists.

    Args:
        filename: path of the checkpoint file.
        generator, discriminator: modules receiving ``generator_state_dict``
            and ``discriminator_state_dict``.
        optimizerG, optimizerD: optimizers receiving ``optimizerG_state_dict``
            and ``optimizerD_state_dict``.
        map_location: forwarded to ``torch.load``. The default ``"cpu"``
            lets a checkpoint written on a GPU machine be resumed on a
            CPU-only host; ``load_state_dict`` then copies the values into
            the existing parameters.

    Returns:
        The epoch to resume from (stored ``epoch`` + 1), or ``0`` when no
        checkpoint file is found.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return 0

    print(f"Loading checkpoint '{filename}'")
    checkpoint = torch.load(filename, map_location=map_location)
    start_epoch = checkpoint['epoch'] + 1
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    return start_epoch
250
+
251
+ # Training function
252
def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5,
          iterations_per_epoch=2, l1_weight=100):
    """Adversarially train the generator and discriminator, tracked in MLflow.

    Resumes from ``checkpoints/latest_checkpoint.pth.tar`` when present.
    After every epoch a preview grid is produced and a checkpoint is saved.
    When all epochs finish, the generator is logged as ``generator_model``
    and registered as ``colorizer_generator``.

    Args:
        generator: model mapping L ``(B, 1, 256, 256)`` to ab ``(B, 2, 256, 256)``.
        discriminator: model scoring the concatenated L+ab tensor.
        train_loader: iterable of ``(real_L, real_AB)`` batches.
        num_epochs: epoch index at which training stops.
        device: device batches are moved to.
        lr: Adam learning rate for both optimizers.
        beta1: first Adam beta for both optimizers.
        iterations_per_epoch: batches consumed from ``train_loader`` per
            epoch. The default (2) keeps the previously hard-coded value.
        l1_weight: multiplier of the L1 term in the generator loss.

    Returns:
        The MLflow run ID on success, ``None`` if training raised.
    """
    criterion = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs("results", exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
        try:
            for epoch in range(start_epoch, num_epochs):
                generator.train()
                discriminator.train()

                pbar = tqdm(
                    enumerate(islice(train_loader, iterations_per_epoch)),
                    total=iterations_per_epoch,
                    desc=f"Epoch {epoch+1}/{num_epochs}",
                )

                for i, (real_L, real_AB) in pbar:
                    # Skip malformed batches rather than crash mid-epoch.
                    if real_L.shape[1:] != (1, 256, 256) or real_AB.shape[1:] != (2, 256, 256):
                        print(f"Unexpected tensor shapes: real_L {real_L.shape}, real_AB {real_AB.shape}")
                        continue

                    real_L, real_AB = real_L.to(device), real_AB.to(device)

                    # ---- Train Discriminator ----
                    optimizerD.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    real_LAB = torch.cat([real_L, real_AB], dim=1)

                    # detach(): the discriminator step must not update the generator.
                    pred_fake = discriminator(fake_LAB.detach())
                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

                    pred_real = discriminator(real_LAB)
                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))

                    loss_D = (loss_D_fake + loss_D_real) * 0.5
                    loss_D.backward()
                    optimizerD.step()

                    # ---- Train Generator ----
                    optimizerG.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    pred_fake = discriminator(fake_LAB)

                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
                    loss_G_L1 = l1_loss(fake_AB, real_AB) * l1_weight

                    loss_G = loss_G_GAN + loss_G_L1
                    loss_G.backward()
                    optimizerG.step()

                    pbar.set_postfix({
                        'D_loss': loss_D.item(),
                        'G_loss': loss_G.item(),
                        'G_L1': loss_G_L1.item()
                    })

                    mlflow.log_metrics({
                        "D_loss": loss_D.item(),
                        "G_loss": loss_G.item(),
                        "G_L1_loss": loss_G_L1.item()
                    }, step=epoch * iterations_per_epoch + i)

                visualize_results(epoch, generator, train_loader, device)

                checkpoint = {
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                }
                save_checkpoint(checkpoint, filename=checkpoint_path)

            print("Training completed successfully.")

            mlflow.pytorch.log_model(generator, "generator_model")
            model_uri = f"runs:/{run.info.run_id}/generator_model"
            mlflow.register_model(model_uri, "colorizer_generator")

            return run.info.run_id

        except Exception as e:
            print(f"Error during training: {str(e)}")
            # Keep the stack trace; the message alone rarely locates the fault.
            traceback.print_exc()
            # Truncate: MLflow limits parameter value length, and an over-long
            # message would otherwise raise from inside this handler.
            mlflow.log_param("error", str(e)[:250])
            return None
351
+
352
+ # Main execution
353
if __name__ == "__main__":
    # Script entry point: build the data loader and both networks, train,
    # and record the resulting MLflow run ID in latest_run_id.txt.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        batch_size, num_epochs = 32, 50

        train_loader = create_dataloaders(batch_size=batch_size)
        if train_loader is None:
            raise Exception("Failed to create dataloader")

        generator = Generator().to(device)
        discriminator = Discriminator().to(device)
        for network in (generator, discriminator):
            network.apply(init_weights)

        run_id = train(generator, discriminator, train_loader, num_epochs=num_epochs, device=device)
        if not run_id:
            print("Training failed!")
        else:
            print(f"Training completed successfully. Run ID: {run_id}")
            with open("latest_run_id.txt", "w") as f:
                f.write(run_id)
    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")
data_ingestion.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import mlflow
3
+ import torch
4
+ from torch.utils.data import IterableDataset, DataLoader
5
+ from torchvision import transforms
6
+ from datasets import load_dataset
7
+ from skimage.color import rgb2lab
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
class ColorizeIterableDataset(IterableDataset):
    """Stream ``(L, ab)`` tensor pairs derived from the ``'image'`` field of
    each item in ``dataset``.

    Every image is converted to RGB, resized to 256x256 and converted to
    LAB. Items that raise during processing are reported and skipped.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def _to_lab_pair(self, record):
        """Return the normalised ``(L, ab)`` tensors for one record."""
        picture = record['image']
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')
        picture = self.transform(picture)

        # rgb2lab wants channels last.
        lab = rgb2lab(picture.permute(1, 2, 0).numpy())

        # L is rescaled with (L - 50) / 50, ab with ab / 128.
        lightness = (lab[:, :, 0] - 50) / 50
        chroma = lab[:, :, 1:] / 128
        return torch.Tensor(lightness).unsqueeze(0), torch.Tensor(chroma).permute(2, 0, 1)

    def __iter__(self):
        for record in self.dataset:
            try:
                yield self._to_lab_pair(record)
            except Exception as e:
                # Best effort: report the bad sample and move on.
                print(f"Error processing image: {str(e)}")
43
+
44
def create_dataloaders(batch_size=32):
    """Build the training ``DataLoader`` over the streamed ``imagenet-1k``
    train split.

    Returns:
        The ``DataLoader`` (4 workers), or ``None`` if any step raised.
    """
    try:
        print("Loading ImageNet dataset in streaming mode...")
        # Streaming avoids downloading the whole dataset up front.
        stream = load_dataset("imagenet-1k", split="train", streaming=True)
        print("Dataset loaded in streaming mode.")

        print("Creating custom dataset...")
        colorize_ds = ColorizeIterableDataset(stream)
        print("Custom dataset created.")

        print("Creating dataloader...")
        loader = DataLoader(colorize_ds, batch_size=batch_size, num_workers=4)
        print("Dataloader created.")
    except Exception as e:
        print(f"Error in create_dataloaders: {str(e)}")
        return None
    return loader
65
+
66
def test_data_ingestion():
    """Smoke-test the ingestion pipeline on a single batch of four images.

    Checks that a dataloader can be created, that it yields at least one
    batch, that the batch has two elements, and that their shapes are
    ``[4, 1, 256, 256]`` (L) and ``[4, 2, 256, 256]`` (ab).

    Returns:
        ``True`` when every check passes, ``False`` otherwise. The function
        never raises; failures are printed.
    """
    print("Testing data ingestion...")
    try:
        dataloader = create_dataloaders(batch_size=4)
        if dataloader is None:
            raise Exception("Dataloader creation failed")

        # Only the first batch is inspected. An empty loader is reported as a
        # failure instead of silently falling through and returning None.
        sample_batch = next(iter(dataloader), None)
        if sample_batch is None:
            raise Exception("Dataloader yielded no batches")

        if len(sample_batch) != 2:
            raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2")

        l_chan, ab_chan = sample_batch
        if l_chan.shape != torch.Size([4, 1, 256, 256]) or ab_chan.shape != torch.Size([4, 2, 256, 256]):
            raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}")

        print("Data ingestion test passed.")
        return True
    except Exception as e:
        print(f"Data ingestion test failed: {str(e)}")
        return False
87
+
88
if __name__ == "__main__":
    # Script entry point: log the ingestion parameters and one preview batch
    # to an MLflow run, then run the ingestion smoke test.
    try:
        print("Starting data ingestion pipeline...")
        mlflow.start_run(run_name="data_ingestion")

        try:
            # Log parameters
            print("Logging parameters...")
            mlflow.log_param("batch_size", 32)
            mlflow.log_param("dataset", "imagenet-1k")
            print("Parameters logged.")

            # Create dataloaders
            print("Creating dataloaders...")
            train_dataloader = create_dataloaders(batch_size=32)
            if train_dataloader is None:
                raise Exception("Failed to create dataloader")
            print("Dataloaders created successfully.")

            # Log a sample batch
            print("Logging sample batch...")
            for sample_batch in train_dataloader:
                l_chan, ab_chan = sample_batch

                # mlflow.log_image accepts H x W or H x W x {1, 3, 4} arrays
                # and requires float data to lie in [0, 1]. The tensors here
                # are channel-first and scaled to [-1, 1], so reshape and
                # rescale them before logging.

                # Sample input (L channel): (1, H, W) -> (H, W) in [0, 1]
                sample_input = np.clip((l_chan[0, 0].numpy() + 1.0) / 2.0, 0.0, 1.0)
                mlflow.log_image(sample_input, "sample_input_l_channel.png")

                # Sample target (ab channels): (2, H, W) -> (H, W, 2) in [0, 1],
                # padded with a zero third channel so it is a valid RGB image.
                ab_hw2 = np.clip((ab_chan[0].permute(1, 2, 0).numpy() + 1.0) / 2.0, 0.0, 1.0)
                sample_target = np.concatenate([ab_hw2, np.zeros_like(ab_hw2[:, :, :1])], axis=2)
                mlflow.log_image(sample_target, "sample_target_ab_channels.png")

                print("Sample batch logged.")
                break  # We only need one batch for logging

            print("Data ingestion pipeline completed successfully.")

        except Exception as e:
            print(f"Error in data ingestion pipeline: {str(e)}")
            mlflow.log_param("error", str(e))

        finally:
            mlflow.end_run()

    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")

    if test_data_ingestion():
        print("All tests passed.")
    else:
        print("Tests failed.")
inference.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import mlflow
4
+ import mlflow.pytorch
5
+ from PIL import Image
6
+ import numpy as np
7
+ from skimage.color import rgb2lab, lab2rgb
8
+ from torchvision import transforms
9
+ import argparse
10
+
11
+ from model import Generator
12
+
13
+ EXPERIMENT_NAME = "Colorizer_Experiment"
14
+
15
def setup_mlflow():
    """Return the MLflow experiment ID for ``EXPERIMENT_NAME``.

    Looks the experiment up by name and creates it when the lookup
    returns ``None``.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
22
+
23
def load_model(run_id, device):
    """Load the generator logged as ``generator_model`` in an MLflow run.

    Args:
        run_id: MLflow run ID whose ``generator_model`` artifact is loaded.
        device: passed to the loader as ``map_location``.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # This script only performs inference: put Dropout/BatchNorm layers into
    # evaluation mode so repeated runs on the same image give the same output.
    model.eval()
    return model
28
+
29
+ # Configuration variables
30
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ RUN_ID = "your_run_id_here" # Replace with the actual run ID
32
+ IMAGE_PATH = "path/to/your/image.jpg" # Replace with the path to your input image
33
+ SAVE_MODEL = False
34
+ SERVE_MODEL = False
35
+ SERVE_PORT = 5000
36
+
37
def preprocess_image(image_path):
    """Read an image file and return the normalised L-channel tensor.

    The image is opened as RGB, resized to 256x256, converted to LAB, and
    the L channel is rescaled with ``(L - 50) / 50``.

    Returns:
        A float tensor with two leading singleton dimensions added in front
        of the L channel.
    """
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    rgb_chw = pipeline(Image.open(image_path).convert("RGB"))
    # rgb2lab expects channels last, so move C to the end first.
    lightness = rgb2lab(rgb_chw.permute(1, 2, 0).numpy())[:, :, 0]
    lightness = (lightness - 50) / 50
    return torch.from_numpy(lightness).unsqueeze(0).unsqueeze(0).float()
49
+
50
def postprocess_output(L, ab):
    """Combine a normalised L tensor and predicted ab tensor into an RGB image.

    Args:
        L: lightness tensor; after ``squeeze()`` it must be ``(H, W)``.
        ab: chroma tensor; after ``squeeze()`` it must be ``(2, H, W)``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    L = L.squeeze().cpu().numpy()
    ab = ab.squeeze().cpu().numpy()
    # Undo the normalisation applied in preprocessing / training.
    L = (L + 1.) * 50.
    ab = ab * 128.
    # ab is channel-first (2, H, W) while L is (H, W): stack them channel-first
    # and only then move channels last. Concatenating (H, W, 1) with (2, H, W)
    # on axis 2, as before, always failed with a shape-mismatch ValueError.
    Lab = np.concatenate([L[np.newaxis, ...], ab], axis=0).transpose(1, 2, 0)
    rgb_img = lab2rgb(Lab)
    return (rgb_img * 255).astype(np.uint8)
58
+
59
def colorize_image(model, image_path, device):
    """Colorize the image at ``image_path`` with ``model``.

    The file goes through ``preprocess_image``, the model is called without
    gradient tracking, and the output is decoded by ``postprocess_output``.
    """
    lightness = preprocess_image(image_path).to(device)
    with torch.no_grad():
        predicted_ab = model(lightness)
    return postprocess_output(lightness, predicted_ab)
65
+
66
def save_model(model, run_id):
    """Log ``model`` under the artifact path ``model`` of an existing run and
    register it as ``colorizer_model``."""
    registered_uri = f"runs:/{run_id}/model"
    with mlflow.start_run(run_id=run_id):
        # Log first: registration needs the artifact to exist in the run.
        mlflow.pytorch.log_model(model, "model")
        mlflow.register_model(registered_uri, "colorizer_model")

    print(f"Model saved and registered with run_id: {run_id}")
76
+
77
def serve_model(run_id, port=5000):
    """Serve the ``model`` artifact of ``run_id`` over HTTP on ``port``.

    ``mlflow.pytorch`` has no ``serve`` function, so the previous call could
    only raise ``AttributeError``. Serving is done through the MLflow CLI
    (``mlflow models serve``), which this function runs as a subprocess. The
    call blocks until the server process exits.

    Raises:
        subprocess.CalledProcessError: if the server exits with a non-zero
            status.
        FileNotFoundError: if the ``mlflow`` executable is not on ``PATH``.
    """
    import subprocess  # local import: only this function needs it

    model_uri = f"runs:/{run_id}/model"
    subprocess.run(
        ["mlflow", "models", "serve", "-m", model_uri, "-p", str(port)],
        check=True,
    )
80
+
81
if __name__ == "__main__":
    # Command-line entry point: colorize one image with a model taken from an
    # MLflow run and record the result in a new "inference_run".
    cli = argparse.ArgumentParser(description="Colorizer Inference")
    cli.add_argument("--run_id", type=str, help="MLflow run ID of the trained model")
    cli.add_argument("--image_path", type=str, required=True, help="Path to the input grayscale image")
    cli.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                     help="Device to use for inference (cuda/cpu)")
    args = cli.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Without --run_id, fall back to the ID stored in latest_run_id.txt.
    if not args.run_id:
        try:
            with open("latest_run_id.txt", "r") as run_id_file:
                args.run_id = run_id_file.read().strip()
        except FileNotFoundError:
            print("No run ID provided and couldn't find latest_run_id.txt")
            exit(1)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="inference_run"):
        try:
            model = load_model(args.run_id, device)
            result = colorize_image(model, args.image_path, device)
            output_path = f"colorized_{os.path.basename(args.image_path)}"
            Image.fromarray(result).save(output_path)
            print(f"Colorized image saved as: {output_path}")

            mlflow.log_artifact(output_path)
            mlflow.log_param("input_image", args.image_path)
            mlflow.log_param("model_run_id", args.run_id)
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            mlflow.log_param("error", str(e))
        finally:
            mlflow.end_run()
instructions.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Create a Colorizer web app. This is a professional project and should follow all necessary coding practices, as well as machine learning and MLOps practices. Use Python, MLflow, and any other libraries and tools of your choice that are required to achieve a fully functioning app. It can be deployed with Gradio.
2
+ 1. Gather a large and diverse dataset of color images. Convert these images to grayscale to create the paired training dataset for the Pix2Pix model.
3
+ 2. Implement a Pix2Pix GAN architecture, utilizing a U-Net generator for colorization and a CNN discriminator for realism evaluation. Ensure the model processes images in the LAB color space.
4
+ 3. Define the loss functions for training: L1 loss for pixel accuracy and adversarial loss for generating realistic images.
5
+ 4. Train the Pix2Pix model using grayscale images as input and color images as targets. Monitor validation performance and apply data augmentation techniques.
6
+ 5. Evaluate the trained model on the test dataset using PSNR and SSIM metrics. Generate and visually inspect colorized images.
7
+ 6. Develop a user interface for the application that allows users to upload images. Use Gradio for this. Integrate the trained model and ensure the output maintains the original image size and quality.
8
+ 7. Add post-processing features to enhance the quality of colorized images. Implement a feedback mechanism for users.
9
+ 8. Perform user testing to gather feedback on the application. Use this feedback to iterate and improve the model and user interface.
10
+
11
+ This is a professional project and should follow all necessary coding practices, as well as machine learning and MLOps practices. Use Python, MLflow, and any other libraries and tools of your choice that are required to achieve a fully functioning app.
12
+ All the functions should be completely implemented and working.
13
+ The project should be containerized and should be able to run in a docker container.
model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class UNetBlock(nn.Module):
    """One stride-2 stage of the U-Net: convolution, optional batch norm, optional dropout, ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        down: If True use ``nn.Conv2d`` (downsampling stage), otherwise
            ``nn.ConvTranspose2d`` (upsampling stage).
        bn: Whether to apply ``nn.BatchNorm2d`` after the convolution.
        dropout: Whether to apply ``nn.Dropout(0.5)`` after the normalisation.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        # Kernel 4, stride 2, padding 1 for both directions; bias is disabled.
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        """Apply conv -> (batch norm) -> (dropout) -> ReLU and return the result."""
        out = self.conv(x)
        for optional_layer in (self.bn, self.dropout):
            if optional_layer is not None:
                out = optional_layer(out)
        # Upsampling stages apply the ReLU in place; downsampling stages do not.
        return F.relu(out, inplace=not self.down)
21
+
22
class Generator(nn.Module):
    """U-Net generator mapping a 1-channel L input to a 2-channel AB output.

    Eight downsampling ``UNetBlock`` stages are mirrored by eight upsampling
    stages; each upsampling stage after the first receives the previous
    decoder output concatenated (on the channel axis) with the matching
    encoder output. The final layer is followed by ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Input is the L channel (1 channel).
        self.down1 = UNetBlock(1, 64, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)

        # Decoder. Input widths from up2 onwards are doubled by the skip connections.
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)
        # Output is the AB channels (2 channels).
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        """Run the encoder/decoder and return the tanh-squashed 2-channel prediction."""
        encoder = (self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7, self.down8)
        inner_decoder = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        # Collect every encoder activation; they are consumed deepest-first below.
        skips = []
        feat = x
        for stage in encoder:
            feat = stage(feat)
            skips.append(feat)

        # The deepest activation feeds the first decoder stage directly (no skip).
        feat = self.up1(skips.pop())
        for stage in inner_decoder:
            feat = stage(torch.cat([feat, skips.pop()], 1))
        return torch.tanh(self.up8(torch.cat([feat, skips.pop()], 1)))
61
+
62
class Discriminator(nn.Module):
    """Convolutional discriminator over 3-channel (L + AB) images.

    Three stride-2 convolutions are followed by two stride-1 convolutions, so
    the output is a single-channel map of raw scores (no final activation is
    applied here); a 256x256 input yields a 30x30 map.
    """

    def __init__(self):
        super().__init__()
        # First stage has no batch norm. Input is L+AB (3 channels).
        layers = [
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (in_channels, out_channels, stride) of the normalised middle stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced by the convolutional stack for ``x``."""
        return self.model(x)
82
+
83
def init_weights(model):
    """Re-initialise one module in place; intended for use with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left untouched.
    """
    layer_kind = type(model).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)
90
+
91
def create_models():
    """Build a freshly initialised ``(generator, discriminator)`` pair.

    Both networks are constructed on the default device and have
    ``init_weights`` applied to every sub-module.

    Returns:
        ``(generator, discriminator)`` on success, or ``(None, None)`` if any
        exception is raised while building them (the error is printed).
    """
    try:
        print("Creating Generator...")
        gen = Generator()
        gen.apply(init_weights)
        print("Generator created successfully.")

        print("Creating Discriminator...")
        disc = Discriminator()
        disc.apply(init_weights)
        print("Discriminator created successfully.")
    except Exception as e:
        print(f"Error in creating models: {str(e)}")
        return None, None
    else:
        return gen, disc
107
+
108
def test_models():
    """Smoke-test both networks by checking the output shape for a random 256x256 input.

    Returns:
        True when the generator maps (1, 1, 256, 256) to (1, 2, 256, 256) and
        the discriminator maps (1, 3, 256, 256) to (1, 1, 30, 30); False (after
        printing the reason) on any failure, including model creation failing.
    """
    print("Testing models...")
    try:
        generator, discriminator = create_models()
        if generator is None or discriminator is None:
            raise Exception("Model creation failed")

        # (label used in the error message, network, input shape, expected output shape)
        shape_checks = (
            ("generator", generator, (1, 1, 256, 256), (1, 2, 256, 256)),
            ("discriminator", discriminator, (1, 3, 256, 256), (1, 1, 30, 30)),
        )
        for label, network, input_shape, expected_shape in shape_checks:
            produced = network(torch.randn(*input_shape))
            if produced.shape != torch.Size(expected_shape):
                raise Exception(f"Unexpected {label} output shape: {produced.shape}")
    except Exception as e:
        print(f"Model test failed: {str(e)}")
        return False

    print("Model test passed.")
    return True
130
+
131
if __name__ == "__main__":
    # Script entry point: build both networks, then run the shape smoke test.
    # Any failure is reported as a single "critical error" line.
    try:
        print("Initializing models...")
        generator, discriminator = create_models()

        if any(network is None for network in (generator, discriminator)):
            raise Exception("Failed to create models")

        tests_ok = test_models()
        if not tests_ok:
            raise Exception("Model testing failed")

        print("Model creation and testing completed successfully.")
    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ mlflow
5
+ numpy
6
+ Pillow
7
+ scikit-image
run_colorizer.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import mlflow
5
+ from data_ingestion import create_dataloaders, test_data_ingestion
6
+ from model import Generator, Discriminator, init_weights, test_models
7
+ from train import train, test_training
8
+ from app import setup_gradio_app
9
+
10
+ EXPERIMENT_NAME = "Colorizer_Experiment"
11
+
12
def setup_mlflow():
    """Return the ID of the MLflow experiment named ``EXPERIMENT_NAME``.

    The experiment is created when no experiment with that name exists yet.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
19
+
20
def _build_models(device):
    """Create the generator and discriminator on ``device`` with ``init_weights`` applied."""
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.apply(init_weights)
    discriminator.apply(init_weights)
    return generator, discriminator


def _ensure_dataloader(train_loader, batch_size, purpose):
    """Return ``train_loader``, creating it first if it is None.

    ``purpose`` is only used in the progress/failure messages. Returns None
    (after printing a message) when the dataloader cannot be created.
    """
    if train_loader is not None:
        return train_loader
    print(f"Creating dataloader for {purpose}...")
    train_loader = create_dataloaders(batch_size=batch_size)
    if train_loader is None:
        print(f"Failed to create dataloader for {purpose}.")
    return train_loader


def run_pipeline(args):
    """Run the pipeline stages selected on the command line, in a fixed order.

    Stages: data ingestion, model creation/testing, training, training smoke
    test, and serving through Gradio. ``args.run_all`` enables every stage
    except the training smoke test, which only runs with ``args.test_training``.
    A failing stage prints a message and aborts the remaining stages.

    Args:
        args: Parsed command-line namespace providing ``device``,
            ``batch_size``, ``num_epochs``, ``run_id``, ``share`` and the
            boolean stage flags.
    """
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Called for its side effect only: makes sure the experiment exists. The ID is not needed here.
    setup_mlflow()

    train_loader = None
    if args.ingest_data or args.run_all:
        print("Starting data ingestion...")
        train_loader = create_dataloaders(batch_size=args.batch_size)
        if train_loader is None:
            print("Data ingestion failed.")
            return

    generator = discriminator = None
    if args.create_model or args.train or args.run_all:
        print("Creating and testing models...")
        generator, discriminator = _build_models(device)
        if not test_models():
            print("Model creation or testing failed.")
            return

    if args.train or args.run_all:
        print("Starting model training...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "training")
        if train_loader is None:
            return
        # The models always exist here: the stage above runs whenever training is requested.
        run_id = train(generator, discriminator, train_loader, num_epochs=args.num_epochs, device=device)
        if not run_id:
            print("Training failed.")
            return
        print(f"Training completed. Run ID: {run_id}")
        # Persist the run ID so later invocations can find the model.
        with open("latest_run_id.txt", "w") as f:
            f.write(run_id)

    if args.test_training:
        print("Testing training process...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "testing")
        if train_loader is None:
            return
        if generator is None or discriminator is None:
            print("Creating models for testing...")
            generator, discriminator = _build_models(device)
        if test_training(generator, discriminator, train_loader, device):
            print("Training process test passed.")
        else:
            print("Training process test failed.")

    if args.serve or args.run_all:
        print("Setting up Gradio app for serving...")
        if not args.run_id:
            # Fall back to the run ID written by the most recent training stage.
            try:
                with open("latest_run_id.txt", "r") as f:
                    args.run_id = f.read().strip()
            except FileNotFoundError:
                print("No run ID provided and couldn't find latest_run_id.txt")
                return
        iface = setup_gradio_app(args.run_id, device)
        iface.launch(share=args.share)
101
+
102
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run_pipeline.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Run Colorizer Pipeline")
    parser.add_argument("--device", type=str, default=default_device,
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train")
    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model for inference")

    # Boolean switches, registered in the same order as they appear in --help.
    switches = (
        ("--ingest_data", "Run data ingestion"),
        ("--create_model", "Create and test the model"),
        ("--train", "Train the model"),
        ("--test_training", "Test the training process"),
        ("--serve", "Serve the model using Gradio"),
        ("--run_all", "Run all steps"),
        ("--share", "Share the Gradio app publicly"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    args = parser.parse_args()

    run_pipeline(args)
train.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader
6
+ from torchvision.utils import make_grid
7
+ import mlflow
8
+ import matplotlib.pyplot as plt
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ from skimage.color import lab2rgb, rgb2lab
12
+ import argparse
13
+ from itertools import islice
14
+ from PIL import Image
15
+ import torchvision.transforms as transforms
16
+
17
+ from data_ingestion import ColorizeIterableDataset, create_dataloaders
18
+ from model import Generator, Discriminator, init_weights
19
+
20
+ EXPERIMENT_NAME = "Colorizer_Experiment"
21
+
22
def setup_mlflow():
    """Return the ID of the MLflow experiment named ``EXPERIMENT_NAME``.

    The experiment is created when no experiment with that name exists yet.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    return mlflow.create_experiment(EXPERIMENT_NAME)
29
+
30
def lab_to_rgb(L, ab):
    """Convert batched L and ab channel tensors to a batch of RGB images.

    ``L`` is rescaled with ``(L + 1) * 50`` and ``ab`` with ``ab * 128``
    before conversion (presumably both arrive normalised to [-1, 1] — confirm
    against the data pipeline). The channels are joined on dim 1, moved to the
    last axis, and each image is converted with ``skimage.color.lab2rgb``.

    Returns:
        A NumPy array stacking the per-image RGB results along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 128.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_image) for lab_image in lab_batch], axis=0)
40
+
41
def preprocess_image(image_path):
    """Load an image and return its normalised LAB channels.

    The image is read as RGB, resized to 256x256, converted with
    ``skimage.color.rgb2lab`` and normalised as
    ``(lab + [0, 128, 128]) / [100, 255, 255]``.

    NOTE(review): ``lab_to_rgb`` in this file undoes a different scaling
    (``(L + 1) * 50`` and ``ab * 128``) — confirm which convention the data
    pipeline actually uses before feeding these values to the model.

    Returns:
        ``(L, ab)``: the first channel as an (H, W) array and the remaining
        two channels as an (H, W, 2) array.
    """
    # Resize to a consistent size.
    rgb_image = Image.open(image_path).convert('RGB').resize((256, 256))
    lab = rgb2lab(rgb_image)
    # Normalise LAB values.
    lab = (lab + [0, 128, 128]) / [100, 255, 255]
    lightness, chroma = lab[:, :, 0], lab[:, :, 1:]
    return lightness, chroma
47
+
48
def visualize_results(epoch, generator, train_loader, device):
    """Save and log a real-vs-generated comparison image for one batch.

    The first batch of ``train_loader`` is colorized with ``generator`` in
    eval mode. Each sample is rendered as the ground-truth RGB image (left)
    next to the generated one (right), the samples are tiled four per row,
    and the figure is written to ``results/epoch_{epoch}.png`` and logged as
    an MLflow artifact. The generator is switched back to train mode
    afterwards, even if an error occurs.

    Args:
        epoch: Epoch index, used in the figure title and file name.
        generator: Network mapping an L batch to an AB batch.
        train_loader: Iterable yielding ``(L, AB)`` batches.
        device: Device the batch is moved to before inference.
    """
    os.makedirs("results", exist_ok=True)
    output_path = f'results/epoch_{epoch}.png'

    generator.eval()
    try:
        with torch.no_grad():
            for inputs, real_AB in train_loader:
                inputs, real_AB = inputs.to(device), real_AB.to(device)
                fake_AB = generator(inputs)

                fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
                real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())

                # Both arrays are (N, H, W, 3). Join them along the width axis so each
                # sample stays a 3-channel image; joining on the channel axis produced
                # 6-channel images that imshow cannot display.
                side_by_side = np.concatenate([real_rgb, fake_rgb], axis=2)
                img_grid = make_grid(torch.from_numpy(side_by_side).permute(0, 3, 1, 2),
                                     normalize=True, nrow=4)

                plt.figure(figsize=(15, 15))
                try:
                    plt.imshow(img_grid.permute(1, 2, 0).cpu())
                    plt.axis('off')
                    plt.title(f'Epoch {epoch}')
                    plt.savefig(output_path)
                finally:
                    # Always release the figure so a failed save does not leak it.
                    plt.close()
                mlflow.log_artifact(output_path)
                break  # Only visualize one batch
    finally:
        generator.train()
69
+
70
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` with ``torch.save`` and log that file via ``mlflow.log_artifact``."""
    torch.save(obj=state, f=filename)
    mlflow.log_artifact(local_path=filename)
73
+
74
def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD):
    """Restore models and optimizers from ``filename`` if that file exists.

    The checkpoint is mapped onto the device the generator currently lives on
    (CPU if it has no parameters), so a checkpoint written on a GPU machine
    can be resumed on a CPU-only one.

    Args:
        filename: Path of the checkpoint written by ``save_checkpoint``.
        generator: Module receiving ``generator_state_dict``.
        discriminator: Module receiving ``discriminator_state_dict``.
        optimizerG: Optimizer receiving ``optimizerG_state_dict``.
        optimizerD: Optimizer receiving ``optimizerD_state_dict``.

    Returns:
        The epoch to resume from (stored epoch + 1), or 0 when no checkpoint
        file is found.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return 0

    print(f"Loading checkpoint '{filename}'")
    first_param = next(generator.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    checkpoint = torch.load(filename, map_location=target_device)

    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    # The stored epoch has already completed, so training resumes at the next one.
    return checkpoint['epoch'] + 1
88
+
89
def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5,
          num_iterations=1000):
    """Adversarially train the colorization generator against the discriminator.

    Training resumes from ``checkpoints/latest_checkpoint.pth.tar`` when that
    file exists. Each epoch consumes at most ``num_iterations`` batches,
    logs per-step losses to MLflow, saves a comparison image and overwrites
    the checkpoint. After the last epoch the generator is logged to MLflow and
    registered as ``colorizer_generator``.

    Args:
        generator: Network mapping an L batch to an AB batch.
        discriminator: Network scoring concatenated L+AB batches.
        train_loader: Iterable yielding ``(L, AB)`` batches.
        num_epochs: Epoch index at which training stops; no epoch runs if the
            restored checkpoint is already at or past it.
        device: Device the batches are moved to.
        lr: Adam learning rate for both optimizers.
        beta1: Adam beta1 for both optimizers (beta2 is fixed at 0.999).
        num_iterations: Maximum number of batches per epoch.

    Returns:
        The MLflow run ID on success, or None if an exception occurred during
        training (the error is printed and logged as the ``error`` param).
    """
    criterion = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs("results", exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
        try:
            for epoch in range(start_epoch, num_epochs):
                generator.train()
                discriminator.train()

                # Use a fixed number of iterations per epoch.
                pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations,
                            desc=f"Epoch {epoch+1}/{num_epochs}")

                for i, (real_L, real_AB) in pbar:
                    real_L, real_AB = real_L.to(device), real_AB.to(device)

                    # Train Discriminator
                    optimizerD.zero_grad()

                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    real_LAB = torch.cat([real_L, real_AB], dim=1)

                    # Detached so the discriminator update does not backpropagate into the generator.
                    pred_fake = discriminator(fake_LAB.detach())
                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

                    pred_real = discriminator(real_LAB)
                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))

                    loss_D = (loss_D_fake + loss_D_real) * 0.5
                    loss_D.backward()
                    optimizerD.step()

                    # Train Generator
                    optimizerG.zero_grad()

                    # Fresh forward pass; this time the output is not detached so
                    # gradients flow back into the generator.
                    fake_AB = generator(real_L)
                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
                    pred_fake = discriminator(fake_LAB)

                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
                    loss_G_L1 = l1_loss(fake_AB, real_AB) * 100  # L1 loss weight

                    loss_G = loss_G_GAN + loss_G_L1
                    loss_G.backward()
                    optimizerG.step()

                    pbar.set_postfix({
                        'D_loss': loss_D.item(),
                        'G_loss': loss_G.item(),
                        'G_L1': loss_G_L1.item()
                    })

                    mlflow.log_metrics({
                        "D_loss": loss_D.item(),
                        "G_loss": loss_G.item(),
                        "G_L1_loss": loss_G_L1.item()
                    }, step=epoch * num_iterations + i)

                visualize_results(epoch, generator, train_loader, device)

                checkpoint = {
                    'epoch': epoch,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                }
                save_checkpoint(checkpoint, filename=checkpoint_path)

            print("Training completed successfully.")

            # Log the generator model
            mlflow.pytorch.log_model(generator, "generator_model")

            # Register the model. Registration is best-effort: the model is already
            # logged under the run, so a registry failure must not discard the run ID.
            model_uri = f"runs:/{run.info.run_id}/generator_model"
            try:
                mlflow.register_model(model_uri, "colorizer_generator")
            except mlflow.exceptions.MlflowException as e:
                print(f"Warning: model registration failed: {str(e)}")

            return run.info.run_id

        except Exception as e:
            print(f"Error during training: {str(e)}")
            # Truncated because MLflow limits the length of param values.
            mlflow.log_param("error", str(e)[:250])
            return None
187
+
188
def test_training(generator, discriminator, train_loader, device):
    """Smoke-test the training loop by running ``train`` for a single epoch.

    ``train`` reports failure by returning None instead of raising, so the
    return value is checked; otherwise this test would report success even
    when training failed.

    Returns:
        True if ``train`` returned a run ID, False if it returned nothing or
        raised.
    """
    print("Testing training process...")
    try:
        run_id = train(generator, discriminator, train_loader, num_epochs=1, device=device)
    except Exception as e:
        print(f"Training process test failed: {str(e)}")
        return False

    if not run_id:
        print("Training process test failed: training did not return a run ID")
        return False

    print("Training process test passed.")
    return True
197
+
198
if __name__ == "__main__":
    # Command-line entry point: train the colorizer, or with --test run the
    # single-epoch training smoke test instead.
    parser = argparse.ArgumentParser(description="Train Colorizer model")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for training (cuda/cpu)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train")
    parser.add_argument("--test", action="store_true", help="Run in test mode")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    try:
        train_loader = create_dataloaders(batch_size=args.batch_size)
        # run_colorizer.py treats a None result as an ingestion failure; do the same
        # here rather than failing later inside the training loop.
        if train_loader is None:
            raise RuntimeError("Failed to create dataloader for training.")

        generator = Generator().to(device)
        discriminator = Discriminator().to(device)

        generator.apply(init_weights)
        discriminator.apply(init_weights)

        if args.test:
            if test_training(generator, discriminator, train_loader, device):
                print("All tests passed.")
            else:
                print("Tests failed.")
        else:
            run_id = train(generator, discriminator, train_loader, num_epochs=args.num_epochs, device=device)
            if run_id:
                print(f"Training completed. Run ID: {run_id}")
                # Save the run ID to a file for easy access by the inference script
                with open("latest_run_id.txt", "w") as f:
                    f.write(run_id)
            else:
                print("Training failed.")

    except Exception as e:
        print(f"Critical error in main execution: {str(e)}")