Rohil Bansal committed
Commit • 02f3f24
1 Parent(s): 8e66116

huggingface spaces commit.
Browse files
- .gitattributes +1 -35
- .gitignore +7 -0
- Dockerfile +10 -0
- README.md +11 -12
- app-mlflow.py +72 -0
- app.py +123 -0
- colorizer_pipeline.py +379 -0
- data_ingestion.py +138 -0
- inference.py +117 -0
- instructions.txt +13 -0
- model.py +144 -0
- requirements.txt +7 -0
- run_colorizer.py +118 -0
- train.py +235 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
+*.pyc
+.env
+__pycache__
+*.pth.tar
+results/
+mlruns/
+# checkpoints/
Dockerfile
ADDED
@@ -0,0 +1,10 @@
+FROM python:3.8
+
+WORKDIR /app
+
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+COPY . .
+
+CMD ["python", "data_ingestion.py"]
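For reference, a minimal build-and-run sketch for this Dockerfile (the colorizer-app tag is an arbitrary choice, not part of the commit). Note that the CMD above runs the data-ingestion script rather than the Gradio app:

docker build -t colorizer-app .
docker run --rm colorizer-app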
README.md
CHANGED
@@ -1,12 +1,11 @@
-
-
-
-
-
-
-
-
-
-
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Image Colorizer
+
+This is a Gradio app that colorizes grayscale images using a trained GAN model. Upload a grayscale image, and the app will return a colorized version of the image.
+
+## How to use
+
+1. Upload a grayscale image using the provided interface.
+2. Wait for the model to process the image.
+3. View the colorized result!
+
+This app uses a GAN model trained on the ImageNet dataset to add color to grayscale images.
app-mlflow.py
ADDED
@@ -0,0 +1,72 @@
+import gradio as gr
+import torch
+import mlflow
+import numpy as np
+from PIL import Image
+from skimage.color import rgb2lab, lab2rgb
+from torchvision import transforms
+
+from model import Generator
+
+EXPERIMENT_NAME = "Colorizer_Experiment"
+RUN_ID = "your_run_id_here"  # Replace with your actual run ID
+
+def setup_mlflow():
+    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
+    else:
+        experiment_id = experiment.experiment_id
+    return experiment_id
+
+def load_model(run_id, device):
+    print(f"Loading model from run: {run_id}")
+    model_uri = f"runs:/{run_id}/generator_model"
+    model = mlflow.pytorch.load_model(model_uri, map_location=device)
+    return model
+
+def preprocess_image(image):
+    img = Image.fromarray(image).convert("RGB")
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor()
+    ])
+    img_tensor = transform(img)
+    lab_img = rgb2lab(img_tensor.permute(1, 2, 0).numpy())
+    L = lab_img[:, :, 0]
+    L = (L - 50) / 50  # Normalize L channel to [-1, 1]
+    L = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float()
+    return L
+
+def postprocess_output(L, ab):
+    L = L.squeeze().cpu().numpy()
+    ab = ab.squeeze().cpu().numpy()
+    L = (L + 1.) * 50.
+    ab = ab * 128.
+    Lab = np.concatenate([L[..., np.newaxis], ab.transpose(1, 2, 0)], axis=2)  # move ab channels last: (2, H, W) -> (H, W, 2)
+    rgb_img = lab2rgb(Lab)
+    return (rgb_img * 255).astype(np.uint8)
+
+def colorize_image(image, model, device):
+    L = preprocess_image(image).to(device)
+    with torch.no_grad():
+        ab = model(L)
+    colorized = postprocess_output(L, ab)
+    return colorized
+
+def setup_gradio_app(run_id, device):
+    model = load_model(run_id, device)
+
+    def gradio_colorize(input_image):
+        colorized = colorize_image(input_image, model, device)
+        return Image.fromarray(colorized)
+
+    iface = gr.Interface(
+        fn=gradio_colorize,
+        inputs=gr.Image(label="Upload a grayscale image"),
+        outputs=gr.Image(label="Colorized Image"),
+        title="Image Colorizer",
+        description="Upload a grayscale image and get a colorized version!",
+    )
+
+    return iface
app.py
ADDED
@@ -0,0 +1,123 @@
+import gradio as gr
+import torch
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+from skimage.color import rgb2lab, lab2rgb
+import os
+from torch import nn
+
+# Define the model architecture (same as in the training script)
+class UNetBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
+        super(UNetBlock, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
+            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels) if bn else None
+        self.dropout = nn.Dropout(0.5) if dropout else None
+        self.down = down
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return nn.ReLU()(x) if self.down else nn.ReLU(inplace=True)(x)
+
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.down1 = UNetBlock(1, 64, bn=False)
+        self.down2 = UNetBlock(64, 128)
+        self.down3 = UNetBlock(128, 256)
+        self.down4 = UNetBlock(256, 512)
+        self.down5 = UNetBlock(512, 512)
+        self.down6 = UNetBlock(512, 512)
+        self.down7 = UNetBlock(512, 512)
+        self.down8 = UNetBlock(512, 512, bn=False)
+
+        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
+        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up4 = UNetBlock(1024, 512, down=False)
+        self.up5 = UNetBlock(1024, 256, down=False)
+        self.up6 = UNetBlock(512, 128, down=False)
+        self.up7 = UNetBlock(256, 64, down=False)
+        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)
+
+    def forward(self, x):
+        d1 = self.down1(x)
+        d2 = self.down2(d1)
+        d3 = self.down3(d2)
+        d4 = self.down4(d3)
+        d5 = self.down5(d4)
+        d6 = self.down6(d5)
+        d7 = self.down7(d6)
+        d8 = self.down8(d7)
+
+        u1 = self.up1(d8)
+        u2 = self.up2(torch.cat([u1, d7], 1))
+        u3 = self.up3(torch.cat([u2, d6], 1))
+        u4 = self.up4(torch.cat([u3, d5], 1))
+        u5 = self.up5(torch.cat([u4, d4], 1))
+        u6 = self.up6(torch.cat([u5, d3], 1))
+        u7 = self.up7(torch.cat([u6, d2], 1))
+        return torch.tanh(self.up8(torch.cat([u7, d1], 1)))
+
+# Load the checkpoint
+def load_checkpoint(filename, generator, map_location):
+    if os.path.isfile(filename):
+        print(f"Loading checkpoint '{filename}'")
+        checkpoint = torch.load(filename, map_location=map_location)
+        generator.load_state_dict(checkpoint['generator_state_dict'])
+        print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
+    else:
+        print(f"No checkpoint found at '{filename}'")
+
+# Initialize the model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+generator = Generator().to(device)
+checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"
+load_checkpoint(checkpoint_path, generator, map_location=device)
+generator.eval()
+
+# Define the transformation
+transform = transforms.Compose([
+    transforms.Resize((256, 256)),
+    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
+    transforms.ToTensor()
+])
+
+# Define the inference function
+def colorize_image(input_image):
+    try:
+        original_size = input_image.size
+        input_image = transform(input_image).unsqueeze(0).to(device) * 2 - 1  # scale L from [0, 1] to [-1, 1] so pre- and post-processing agree with the training normalization
+        with torch.no_grad():
+            output = generator(input_image)
+        output = output.squeeze(0).cpu().numpy()
+        L = input_image.squeeze(0).cpu().numpy()
+        L = (L + 1.) * 50.
+        ab = output * 128.
+        Lab = np.concatenate([L, ab], axis=0).transpose(1, 2, 0)
+        rgb_image = lab2rgb(Lab)
+        rgb_image = Image.fromarray((rgb_image * 255).astype(np.uint8))
+        rgb_image = rgb_image.resize(original_size, Image.LANCZOS)
+        return rgb_image
+    except Exception as e:
+        print(f"Error in colorize_image: {str(e)}")
+        return None
+
+# Create the Gradio interface
+iface = gr.Interface(
+    fn=colorize_image,
+    inputs=gr.Image(type="pil"),
+    outputs=gr.Image(type="pil"),
+    title="Image Colorizer",
+    description="Upload a grayscale image to colorize it."
+)
+
+# Launch the app
+if __name__ == "__main__":
+    iface.launch()
colorizer_pipeline.py
ADDED
@@ -0,0 +1,379 @@
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import IterableDataset, DataLoader
+from torchvision import transforms
+from torchvision.utils import make_grid
+import mlflow
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import numpy as np
+from skimage.color import rgb2lab, lab2rgb
+from datasets import load_dataset
+from PIL import Image
+from itertools import islice
+import traceback
+
+# MLflow setup
+EXPERIMENT_NAME = "Colorizer_Experiment"
+
+def setup_mlflow():
+    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
+    else:
+        experiment_id = experiment.experiment_id
+    return experiment_id
+
+# Data ingestion
+class ColorizeIterableDataset(IterableDataset):
+    def __init__(self, dataset, transform=None):
+        self.dataset = dataset
+        self.transform = transform
+
+    def __iter__(self):
+        for item in self.dataset:
+            try:
+                img = item['image']
+                if img.mode != 'RGB':
+                    img = img.convert('RGB')
+                if self.transform:
+                    img = self.transform(img)
+
+                # Add shape check after transform
+                if img.shape != (3, 256, 256):
+                    print(f"Unexpected image shape after transform: {img.shape}")
+                    continue
+
+                lab = rgb2lab(img.permute(1, 2, 0).numpy())
+
+                # Add shape check after rgb2lab conversion
+                if lab.shape != (256, 256, 3):
+                    print(f"Unexpected lab shape: {lab.shape}")
+                    continue
+
+                l_chan = lab[:, :, 0]
+                l_chan = (l_chan - 50) / 50
+                ab_chan = lab[:, :, 1:]
+                ab_chan = ab_chan / 128
+
+                yield torch.Tensor(l_chan).unsqueeze(0), torch.Tensor(ab_chan).permute(2, 0, 1)
+            except Exception as e:
+                print(f"Error processing image: {str(e)}")
+                continue
+
+def create_dataloaders(batch_size=32):
+    try:
+        print("Loading ImageNet dataset in streaming mode...")
+        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
+        print("Dataset loaded in streaming mode.")
+
+        print("Creating custom dataset...")
+        transform = transforms.Compose([
+            transforms.Resize((256, 256)),  # Resize all images to 256x256
+            transforms.ToTensor()
+        ])
+        train_dataset = ColorizeIterableDataset(dataset, transform=transform)
+        print("Custom dataset created.")
+
+        print("Creating dataloader...")
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4)
+        print("Dataloader created.")
+
+        return train_dataloader
+    except Exception as e:
+        print(f"Error in create_dataloaders: {str(e)}")
+        return None
+
+# Model definition
+class UNetBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
+        super(UNetBlock, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
+            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels) if bn else None
+        self.dropout = nn.Dropout(0.5) if dropout else None
+        self.down = down
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return nn.ReLU()(x) if self.down else nn.ReLU(inplace=True)(x)
+
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.down1 = UNetBlock(1, 64, bn=False)
+        self.down2 = UNetBlock(64, 128)
+        self.down3 = UNetBlock(128, 256)
+        self.down4 = UNetBlock(256, 512)
+        self.down5 = UNetBlock(512, 512)
+        self.down6 = UNetBlock(512, 512)
+        self.down7 = UNetBlock(512, 512)
+        self.down8 = UNetBlock(512, 512, bn=False)
+
+        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
+        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up4 = UNetBlock(1024, 512, down=False)
+        self.up5 = UNetBlock(1024, 256, down=False)
+        self.up6 = UNetBlock(512, 128, down=False)
+        self.up7 = UNetBlock(256, 64, down=False)
+        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)
+
+    def forward(self, x):
+        d1 = self.down1(x)
+        d2 = self.down2(d1)
+        d3 = self.down3(d2)
+        d4 = self.down4(d3)
+        d5 = self.down5(d4)
+        d6 = self.down6(d5)
+        d7 = self.down7(d6)
+        d8 = self.down8(d7)
+
+        u1 = self.up1(d8)
+        u2 = self.up2(torch.cat([u1, d7], 1))
+        u3 = self.up3(torch.cat([u2, d6], 1))
+        u4 = self.up4(torch.cat([u3, d5], 1))
+        u5 = self.up5(torch.cat([u4, d4], 1))
+        u6 = self.up6(torch.cat([u5, d3], 1))
+        u7 = self.up7(torch.cat([u6, d2], 1))
+        return torch.tanh(self.up8(torch.cat([u7, d1], 1)))
+
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(3, 64, 4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 128, 4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(128, 256, 4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(256, 512, 4, padding=1),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(512, 1, 4, padding=1)
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+def init_weights(model):
+    classname = model.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(model.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(model.weight.data, 1.0, 0.02)
+        nn.init.constant_(model.bias.data, 0)
+
+# Training utilities
+def lab_to_rgb(L, ab):
+    L = (L + 1.) * 50.
+    ab = ab * 128.
+    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
+    rgb_imgs = []
+    for img in Lab:
+        img_rgb = lab2rgb(img)
+        rgb_imgs.append(img_rgb)
+    return np.stack(rgb_imgs, axis=0)
+
+def visualize_results(epoch, generator, train_loader, device):
+    generator.eval()
+    with torch.no_grad():
+        try:
+            for inputs, real_AB in train_loader:
+                print(f"Input shape: {inputs.shape}, real_AB shape: {real_AB.shape}")
+
+                # Ensure inputs have the correct shape (B, 1, H, W)
+                if inputs.shape[1] != 1:
+                    inputs = inputs.unsqueeze(1)
+
+                inputs, real_AB = inputs.to(device), real_AB.to(device)
+                fake_AB = generator(inputs)
+
+                print(f"fake_AB shape: {fake_AB.shape}")
+
+                # Ensure fake_AB and real_AB have the correct shape (B, 2, H, W)
+                if fake_AB.shape[1] != 2:
+                    fake_AB = fake_AB.view(fake_AB.shape[0], 2, fake_AB.shape[2], fake_AB.shape[3])
+                if real_AB.shape[1] != 2:
+                    real_AB = real_AB.view(real_AB.shape[0], 2, real_AB.shape[2], real_AB.shape[3])
+
+                fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
+                real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())
+
+                print(f"fake_rgb shape: {fake_rgb.shape}, real_rgb shape: {real_rgb.shape}")
+
+                concatenated = np.concatenate([real_rgb, fake_rgb], axis=2)  # Changed axis from 3 to 2
+                print(f"Concatenated shape: {concatenated.shape}")
+
+                img_grid = make_grid(torch.from_numpy(concatenated).permute(0, 3, 1, 2), normalize=True, nrow=4)
+
+                plt.figure(figsize=(15, 15))
+                plt.imshow(img_grid.permute(1, 2, 0).cpu())
+                plt.axis('off')
+                plt.title(f'Epoch {epoch}')
+                plt.savefig(f'results/epoch_{epoch}.png')
+                mlflow.log_artifact(f'results/epoch_{epoch}.png')
+                plt.close()
+                break
+        except Exception as e:
+            print(f"Error in visualize_results: {str(e)}")
+            traceback.print_exc()
+    generator.train()
+
+def save_checkpoint(state, filename="checkpoint.pth.tar"):
+    torch.save(state, filename)
+    mlflow.log_artifact(filename)
+
+def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD):
+    if os.path.isfile(filename):
+        print(f"Loading checkpoint '{filename}'")
+        checkpoint = torch.load(filename)
+        start_epoch = checkpoint['epoch'] + 1
+        generator.load_state_dict(checkpoint['generator_state_dict'])
+        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
+        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
+        print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
+        return start_epoch
+    else:
+        print(f"No checkpoint found at '{filename}'")
+        return 0
+
+# Training function
+def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5):
+    criterion = nn.BCEWithLogitsLoss()
+    l1_loss = nn.L1Loss()
+
+    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
+    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
+
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    os.makedirs("results", exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
+    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD)
+
+    experiment_id = setup_mlflow()
+    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
+        try:
+            for epoch in range(start_epoch, num_epochs):
+                generator.train()
+                discriminator.train()
+
+                num_iterations = 2
+                pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}")
+
+                for i, (real_L, real_AB) in pbar:
+                    # Add shape check
+                    if real_L.shape[1:] != (1, 256, 256) or real_AB.shape[1:] != (2, 256, 256):
+                        print(f"Unexpected tensor shapes: real_L {real_L.shape}, real_AB {real_AB.shape}")
+                        continue
+
+                    real_L, real_AB = real_L.to(device), real_AB.to(device)
+                    batch_size = real_L.size(0)
+
+                    # Train Discriminator
+                    optimizerD.zero_grad()
+
+                    fake_AB = generator(real_L)
+                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
+                    real_LAB = torch.cat([real_L, real_AB], dim=1)
+
+                    pred_fake = discriminator(fake_LAB.detach())
+                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
+
+                    pred_real = discriminator(real_LAB)
+                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))
+
+                    loss_D = (loss_D_fake + loss_D_real) * 0.5
+                    loss_D.backward()
+                    optimizerD.step()
+
+                    # Train Generator
+                    optimizerG.zero_grad()
+
+                    fake_AB = generator(real_L)
+                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
+                    pred_fake = discriminator(fake_LAB)
+
+                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
+                    loss_G_L1 = l1_loss(fake_AB, real_AB) * 100  # L1 loss weight
+
+                    loss_G = loss_G_GAN + loss_G_L1
+                    loss_G.backward()
+                    optimizerG.step()
+
+                    pbar.set_postfix({
+                        'D_loss': loss_D.item(),
+                        'G_loss': loss_G.item(),
+                        'G_L1': loss_G_L1.item()
+                    })
+
+                    mlflow.log_metrics({
+                        "D_loss": loss_D.item(),
+                        "G_loss": loss_G.item(),
+                        "G_L1_loss": loss_G_L1.item()
+                    }, step=epoch * num_iterations + i)
+
+                visualize_results(epoch, generator, train_loader, device)
+
+                checkpoint = {
+                    'epoch': epoch,
+                    'generator_state_dict': generator.state_dict(),
+                    'discriminator_state_dict': discriminator.state_dict(),
+                    'optimizerG_state_dict': optimizerG.state_dict(),
+                    'optimizerD_state_dict': optimizerD.state_dict(),
+                }
+                save_checkpoint(checkpoint, filename=checkpoint_path)
+
+            print("Training completed successfully.")
+
+            mlflow.pytorch.log_model(generator, "generator_model")
+            model_uri = f"runs:/{run.info.run_id}/generator_model"
+            mlflow.register_model(model_uri, "colorizer_generator")
+
+            return run.info.run_id
+
+        except Exception as e:
+            print(f"Error during training: {str(e)}")
+            mlflow.log_param("error", str(e))
+            return None
+
+# Main execution
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    try:
+        batch_size = 32
+        num_epochs = 50
+
+        train_loader = create_dataloaders(batch_size=batch_size)
+        if train_loader is None:
+            raise Exception("Failed to create dataloader")
+
+        generator = Generator().to(device)
+        discriminator = Discriminator().to(device)
+
+        generator.apply(init_weights)
+        discriminator.apply(init_weights)
+
+        run_id = train(generator, discriminator, train_loader, num_epochs=num_epochs, device=device)
+        if run_id:
+            print(f"Training completed successfully. Run ID: {run_id}")
+            with open("latest_run_id.txt", "w") as f:
+                f.write(run_id)
+        else:
+            print("Training failed!")
+    except Exception as e:
+        print(f"Critical error in main execution: {str(e)}")
data_ingestion.py
ADDED
@@ -0,0 +1,138 @@
+import os
+import mlflow
+import torch
+from torch.utils.data import IterableDataset, DataLoader
+from torchvision import transforms
+from datasets import load_dataset
+from skimage.color import rgb2lab
+from PIL import Image
+import numpy as np
+
+
+class ColorizeIterableDataset(IterableDataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+        self.transform = transforms.Compose([
+            transforms.Resize((256, 256)),
+            transforms.ToTensor()
+        ])
+
+    def __iter__(self):
+        for item in self.dataset:
+            try:
+                img = item['image']
+                if img.mode != 'RGB':
+                    img = img.convert('RGB')
+                img = self.transform(img)
+
+                # Convert to LAB color space
+                lab = rgb2lab(img.permute(1, 2, 0).numpy())
+
+                # Normalize L channel to range [-1, 1]
+                l_chan = lab[:, :, 0]
+                l_chan = (l_chan - 50) / 50
+
+                # Normalize AB channels to range [-1, 1]
+                ab_chan = lab[:, :, 1:]
+                ab_chan = ab_chan / 128
+
+                yield torch.Tensor(l_chan).unsqueeze(0), torch.Tensor(ab_chan).permute(2, 0, 1)
+            except Exception as e:
+                print(f"Error processing image: {str(e)}")
+                continue
+
+def create_dataloaders(batch_size=32):
+    try:
+        print("Loading ImageNet dataset in streaming mode...")
+        # Load ImageNet dataset from Hugging Face in streaming mode
+        dataset = load_dataset("imagenet-1k", split="train", streaming=True)
+        print("Dataset loaded in streaming mode.")
+
+        print("Creating custom dataset...")
+        # Create custom dataset
+        train_dataset = ColorizeIterableDataset(dataset)
+        print("Custom dataset created.")
+
+        print("Creating dataloader...")
+        # Create dataloader
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4)
+        print("Dataloader created.")
+
+        return train_dataloader
+    except Exception as e:
+        print(f"Error in create_dataloaders: {str(e)}")
+        return None
+
+def test_data_ingestion():
+    print("Testing data ingestion...")
+    try:
+        dataloader = create_dataloaders(batch_size=4)
+        if dataloader is None:
+            raise Exception("Dataloader creation failed")
+
+        # Get the first batch
+        for sample_batch in dataloader:
+            if len(sample_batch) != 2:
+                raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2")
+
+            l_chan, ab_chan = sample_batch
+            if l_chan.shape != torch.Size([4, 1, 256, 256]) or ab_chan.shape != torch.Size([4, 2, 256, 256]):
+                raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}")
+
+            print("Data ingestion test passed.")
+            return True
+    except Exception as e:
+        print(f"Data ingestion test failed: {str(e)}")
+        return False
+
+if __name__ == "__main__":
+    try:
+        print("Starting data ingestion pipeline...")
+        mlflow.start_run(run_name="data_ingestion")
+
+        try:
+            # Log parameters
+            print("Logging parameters...")
+            mlflow.log_param("batch_size", 32)
+            mlflow.log_param("dataset", "imagenet-1k")
+            print("Parameters logged.")
+
+            # Create dataloaders
+            print("Creating dataloaders...")
+            train_dataloader = create_dataloaders(batch_size=32)
+            if train_dataloader is None:
+                raise Exception("Failed to create dataloader")
+            print("Dataloaders created successfully.")
+
+            # Log a sample batch
+            print("Logging sample batch...")
+            for sample_batch in train_dataloader:
+                l_chan, ab_chan = sample_batch
+
+                # Log sample input (L channel)
+                sample_input = l_chan[0].numpy()
+                mlflow.log_image(sample_input, "sample_input_l_channel.png")
+
+                # Log sample target (AB channels)
+                sample_target = ab_chan[0].permute(1, 2, 0).numpy()
+                mlflow.log_image(sample_target, "sample_target_ab_channels.png")
+
+                print("Sample batch logged.")
+                break  # We only need one batch for logging
+
+            print("Data ingestion pipeline completed successfully.")
+
+        except Exception as e:
+            print(f"Error in data ingestion pipeline: {str(e)}")
+            mlflow.log_param("error", str(e))
+
+        finally:
+            mlflow.end_run()
+
+    except Exception as e:
+        print(f"Critical error in main execution: {str(e)}")
+
+    if test_data_ingestion():
+        print("All tests passed.")
+    else:
+        print("Tests failed.")
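For reference, a minimal sketch of the LAB normalization round trip this dataset (and the postprocessing in app-mlflow.py and inference.py) relies on, using the same constants as the code above:

import numpy as np
from skimage.color import rgb2lab, lab2rgb

rgb = np.random.rand(8, 8, 3)              # any RGB image with values in [0, 1]
lab = rgb2lab(rgb)                         # L in [0, 100], ab roughly in [-128, 127]
L_norm = (lab[:, :, 0] - 50) / 50          # L scaled to [-1, 1], as in __iter__ above
ab_norm = lab[:, :, 1:] / 128              # ab scaled to roughly [-1, 1]

# Undoing the normalization recovers an image close to the original
lab_back = np.concatenate([(L_norm[..., None] + 1.) * 50., ab_norm * 128.], axis=2)
rgb_back = lab2rgb(lab_back)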
inference.py
ADDED
@@ -0,0 +1,117 @@
+import os
+import torch
+import mlflow
+import mlflow.pytorch
+from PIL import Image
+import numpy as np
+from skimage.color import rgb2lab, lab2rgb
+from torchvision import transforms
+import argparse
+
+from model import Generator
+
+EXPERIMENT_NAME = "Colorizer_Experiment"
+
+def setup_mlflow():
+    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
+    else:
+        experiment_id = experiment.experiment_id
+    return experiment_id
+
+def load_model(run_id, device):
+    print(f"Loading model from run: {run_id}")
+    model_uri = f"runs:/{run_id}/generator_model"
+    model = mlflow.pytorch.load_model(model_uri, map_location=device)
+    return model
+
+# Configuration variables
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+RUN_ID = "your_run_id_here"  # Replace with the actual run ID
+IMAGE_PATH = "path/to/your/image.jpg"  # Replace with the path to your input image
+SAVE_MODEL = False
+SERVE_MODEL = False
+SERVE_PORT = 5000
+
+def preprocess_image(image_path):
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),
+        transforms.ToTensor()
+    ])
+    img_tensor = transform(img)
+    lab_img = rgb2lab(img_tensor.permute(1, 2, 0).numpy())
+    L = lab_img[:, :, 0]
+    L = (L - 50) / 50  # Normalize L channel to [-1, 1]
+    L = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float()
+    return L
+
+def postprocess_output(L, ab):
+    L = L.squeeze().cpu().numpy()
+    ab = ab.squeeze().cpu().numpy()
+    L = (L + 1.) * 50.
+    ab = ab * 128.
+    Lab = np.concatenate([L[..., np.newaxis], ab.transpose(1, 2, 0)], axis=2)  # move ab channels last: (2, H, W) -> (H, W, 2)
+    rgb_img = lab2rgb(Lab)
+    return (rgb_img * 255).astype(np.uint8)
+
+def colorize_image(model, image_path, device):
+    L = preprocess_image(image_path).to(device)
+    with torch.no_grad():
+        ab = model(L)
+    colorized = postprocess_output(L, ab)
+    return colorized
+
+def save_model(model, run_id):
+    with mlflow.start_run(run_id=run_id):
+        # Log the model
+        mlflow.pytorch.log_model(model, "model")
+
+        # Register the model
+        model_uri = f"runs:/{run_id}/model"
+        mlflow.register_model(model_uri, "colorizer_model")
+
+    print(f"Model saved and registered with run_id: {run_id}")
+
+def serve_model(run_id, port=5000):
+    model_uri = f"runs:/{run_id}/model"
+    os.system(f"mlflow models serve -m {model_uri} -p {port}")  # MLflow serves models through its CLI, not a Python API
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Colorizer Inference")
+    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model")
+    parser.add_argument("--image_path", type=str, required=True, help="Path to the input grayscale image")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
+                        help="Device to use for inference (cuda/cpu)")
+    args = parser.parse_args()
+
+    device = torch.device(args.device)
+    print(f"Using device: {device}")
+
+    # If run_id is not provided, try to load it from the file
+    if not args.run_id:
+        try:
+            with open("latest_run_id.txt", "r") as f:
+                args.run_id = f.read().strip()
+        except FileNotFoundError:
+            print("No run ID provided and couldn't find latest_run_id.txt")
+            exit(1)
+
+    experiment_id = setup_mlflow()
+    with mlflow.start_run(experiment_id=experiment_id, run_name="inference_run"):
+        try:
+            model = load_model(args.run_id, device)
+            colorized = colorize_image(model, args.image_path, device)
+            output_path = f"colorized_{os.path.basename(args.image_path)}"
+            Image.fromarray(colorized).save(output_path)
+            print(f"Colorized image saved as: {output_path}")
+
+            mlflow.log_artifact(output_path)
+            mlflow.log_param("input_image", args.image_path)
+            mlflow.log_param("model_run_id", args.run_id)
+        except Exception as e:
+            print(f"Error during inference: {str(e)}")
+            mlflow.log_param("error", str(e))
+        finally:
+            mlflow.end_run()
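A typical invocation of this script (the image path is illustrative):

python inference.py --image_path photos/example.jpg --run_id <mlflow-run-id>

If --run_id is omitted, the script falls back to the run ID stored in latest_run_id.txt by the training step.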
instructions.txt
ADDED
@@ -0,0 +1,13 @@
+Create a Colorizer webapp. This is a professional project and should follow all necessary coding practices, machine learning as well as MLOps practices. Use Python, MLflow, etc., as well as other libraries and tools of your choice which might be required to achieve a fully functioning app. It can be deployed on Gradio.
+1. Gather a large and diverse dataset of color images. Convert these images to grayscale to create the paired training dataset for the Pix2Pix model.
+2. Implement a Pix2Pix GAN architecture, utilizing a U-Net generator for colorization and a CNN discriminator for realism evaluation. Ensure the model processes images in the LAB color space.
+3. Define the loss functions for training: L1 loss for pixel accuracy and adversarial loss for generating realistic images.
+4. Train the Pix2Pix model using grayscale images as input and color images as targets. Monitor validation performance and apply data augmentation techniques.
+5. Evaluate the trained model on the test dataset using PSNR and SSIM metrics. Generate and visually inspect colorized images.
+6. Develop a user interface for the application that allows users to upload images. Use Gradio for this. Integrate the trained model and ensure the output maintains the original image size and quality.
+7. Add post-processing features to enhance the quality of colorized images. Implement a feedback mechanism for users.
+8. Perform user testing to gather feedback on the application. Use this feedback to iterate and improve the model and user interface.
+
+This is a professional project and should follow all necessary coding practices, machine learning as well as MLOps practices. Use Python, MLflow, etc., as well as other libraries and tools of your choice which might be required to achieve a fully functioning app.
+All the functions should be completely implemented and working.
+The project should be containerized and should be able to run in a Docker container.
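For reference, the objective sketched in steps 2-3 is the standard Pix2Pix combination of an adversarial term and an L1 term; train.py weights the L1 term by 100:

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \qquad \lambda = 100$$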
model.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class UNetBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
+        super(UNetBlock, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
+            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels) if bn else None
+        self.dropout = nn.Dropout(0.5) if dropout else None
+        self.down = down
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn:
+            x = self.bn(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return F.relu(x) if self.down else F.relu(x, inplace=True)
+
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.down1 = UNetBlock(1, 64, bn=False)  # Input is L channel (1 channel)
+        self.down2 = UNetBlock(64, 128)
+        self.down3 = UNetBlock(128, 256)
+        self.down4 = UNetBlock(256, 512)
+        self.down5 = UNetBlock(512, 512)
+        self.down6 = UNetBlock(512, 512)
+        self.down7 = UNetBlock(512, 512)
+        self.down8 = UNetBlock(512, 512, bn=False)
+
+        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
+        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
+        self.up4 = UNetBlock(1024, 512, down=False)
+        self.up5 = UNetBlock(1024, 256, down=False)
+        self.up6 = UNetBlock(512, 128, down=False)
+        self.up7 = UNetBlock(256, 64, down=False)
+        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)  # Output is AB channels (2 channels)
+
+    def forward(self, x):
+        d1 = self.down1(x)
+        d2 = self.down2(d1)
+        d3 = self.down3(d2)
+        d4 = self.down4(d3)
+        d5 = self.down5(d4)
+        d6 = self.down6(d5)
+        d7 = self.down7(d6)
+        d8 = self.down8(d7)
+
+        u1 = self.up1(d8)
+        u2 = self.up2(torch.cat([u1, d7], 1))
+        u3 = self.up3(torch.cat([u2, d6], 1))
+        u4 = self.up4(torch.cat([u3, d5], 1))
+        u5 = self.up5(torch.cat([u4, d4], 1))
+        u6 = self.up6(torch.cat([u5, d3], 1))
+        u7 = self.up7(torch.cat([u6, d2], 1))
+        return torch.tanh(self.up8(torch.cat([u7, d1], 1)))
+
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(3, 64, 4, stride=2, padding=1),  # Input is L+AB (3 channels)
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 128, 4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(128, 256, 4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(256, 512, 4, padding=1),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(512, 1, 4, padding=1)
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+def init_weights(model):
+    classname = model.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(model.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(model.weight.data, 1.0, 0.02)
+        nn.init.constant_(model.bias.data, 0)
+
+def create_models():
+    try:
+        print("Creating Generator...")
+        generator = Generator()
+        generator.apply(init_weights)
+        print("Generator created successfully.")
+
+        print("Creating Discriminator...")
+        discriminator = Discriminator()
+        discriminator.apply(init_weights)
+        print("Discriminator created successfully.")
+
+        return generator, discriminator
+    except Exception as e:
+        print(f"Error in creating models: {str(e)}")
+        return None, None
+
+def test_models():
+    print("Testing models...")
+    try:
+        generator, discriminator = create_models()
+        if generator is None or discriminator is None:
+            raise Exception("Model creation failed")
+
+        test_input_g = torch.randn(1, 1, 256, 256)
+        test_output_g = generator(test_input_g)
+        if test_output_g.shape != torch.Size([1, 2, 256, 256]):
+            raise Exception(f"Unexpected generator output shape: {test_output_g.shape}")
+
+        test_input_d = torch.randn(1, 3, 256, 256)
+        test_output_d = discriminator(test_input_d)
+        if test_output_d.shape != torch.Size([1, 1, 30, 30]):
+            raise Exception(f"Unexpected discriminator output shape: {test_output_d.shape}")
+
+        print("Model test passed.")
+        return True
+    except Exception as e:
+        print(f"Model test failed: {str(e)}")
+        return False
+
+if __name__ == "__main__":
+    try:
+        print("Initializing models...")
+        generator, discriminator = create_models()
+
+        if generator is None or discriminator is None:
+            raise Exception("Failed to create models")
+
+        if not test_models():
+            raise Exception("Model testing failed")
+
+        print("Model creation and testing completed successfully.")
+    except Exception as e:
+        print(f"Critical error in main execution: {str(e)}")
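A minimal shape check, mirroring what test_models asserts. The decoder blocks after up1 take 1024 input channels because each 512-channel decoder activation is concatenated with the matching 512-channel encoder skip connection:

import torch
from model import Generator, Discriminator

g, d = Generator(), Discriminator()
ab = g(torch.randn(1, 1, 256, 256))             # L channel in, predicted ab channels out
print(ab.shape)                                 # torch.Size([1, 2, 256, 256])
print(d(torch.randn(1, 3, 256, 256)).shape)     # patch-level output: torch.Size([1, 1, 30, 30])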
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+gradio
+torch
+torchvision
+mlflow
+numpy
+Pillow
+scikit-image
run_colorizer.py
ADDED
@@ -0,0 +1,118 @@
+import argparse
+import os
+import torch
+import mlflow
+from data_ingestion import create_dataloaders, test_data_ingestion
+from model import Generator, Discriminator, init_weights, test_models
+from train import train, test_training
+from app import setup_gradio_app  # NOTE: setup_gradio_app is defined in app-mlflow.py, not app.py
+
+EXPERIMENT_NAME = "Colorizer_Experiment"
+
+def setup_mlflow():
+    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
+    else:
+        experiment_id = experiment.experiment_id
+    return experiment_id
+
+def run_pipeline(args):
+    device = torch.device(args.device)
+    print(f"Using device: {device}")
+
+    experiment_id = setup_mlflow()
+
+    if args.ingest_data or args.run_all:
+        print("Starting data ingestion...")
+        train_loader = create_dataloaders(batch_size=args.batch_size)
+        if train_loader is None:
+            print("Data ingestion failed.")
+            return
+    else:
+        train_loader = None
+
+    if args.create_model or args.train or args.run_all:
+        print("Creating and testing models...")
+        generator = Generator().to(device)
+        discriminator = Discriminator().to(device)
+        generator.apply(init_weights)
+        discriminator.apply(init_weights)
+        if not test_models():
+            print("Model creation or testing failed.")
+            return
+    else:
+        generator = None
+        discriminator = None
+
+    if args.train or args.run_all:
+        print("Starting model training...")
+        if train_loader is None:
+            print("Creating dataloader for training...")
+            train_loader = create_dataloaders(batch_size=args.batch_size)
+            if train_loader is None:
+                print("Failed to create dataloader for training.")
+                return
+        if generator is None or discriminator is None:
+            print("Creating models for training...")
+            generator = Generator().to(device)
+            discriminator = Discriminator().to(device)
+            generator.apply(init_weights)
+            discriminator.apply(init_weights)
+        run_id = train(generator, discriminator, train_loader, num_epochs=args.num_epochs, device=device)
+        if run_id:
+            print(f"Training completed. Run ID: {run_id}")
+            with open("latest_run_id.txt", "w") as f:
+                f.write(run_id)
+        else:
+            print("Training failed.")
+            return
+
+    if args.test_training:
+        print("Testing training process...")
+        if train_loader is None:
+            print("Creating dataloader for testing...")
+            train_loader = create_dataloaders(batch_size=args.batch_size)
+            if train_loader is None:
+                print("Failed to create dataloader for testing.")
+                return
+        if generator is None or discriminator is None:
+            print("Creating models for testing...")
+            generator = Generator().to(device)
+            discriminator = Discriminator().to(device)
+            generator.apply(init_weights)
+            discriminator.apply(init_weights)
+        if test_training(generator, discriminator, train_loader, device):
+            print("Training process test passed.")
+        else:
+            print("Training process test failed.")
+
+    if args.serve or args.run_all:
+        print("Setting up Gradio app for serving...")
+        if not args.run_id:
+            try:
+                with open("latest_run_id.txt", "r") as f:
+                    args.run_id = f.read().strip()
+            except FileNotFoundError:
+                print("No run ID provided and couldn't find latest_run_id.txt")
+                return
+        iface = setup_gradio_app(args.run_id, device)
+        iface.launch(share=args.share)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run Colorizer Pipeline")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
+                        help="Device to use (cuda/cpu)")
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
+    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train")
+    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model for inference")
+    parser.add_argument("--ingest_data", action="store_true", help="Run data ingestion")
+    parser.add_argument("--create_model", action="store_true", help="Create and test the model")
+    parser.add_argument("--train", action="store_true", help="Train the model")
+    parser.add_argument("--test_training", action="store_true", help="Test the training process")
+    parser.add_argument("--serve", action="store_true", help="Serve the model using Gradio")
+    parser.add_argument("--run_all", action="store_true", help="Run all steps")
+    parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")
+    args = parser.parse_args()
+
+    run_pipeline(args)
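A typical end-to-end invocation, using the flags defined above:

python run_colorizer.py --run_all --batch_size 32 --num_epochs 50

Individual stages can also be run in isolation, e.g. python run_colorizer.py --ingest_data or python run_colorizer.py --serve --run_id <mlflow-run-id>.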
train.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from torchvision.utils import make_grid
|
7 |
+
import mlflow
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from tqdm import tqdm
|
10 |
+
import numpy as np
|
11 |
+
from skimage.color import lab2rgb, rgb2lab
|
12 |
+
import argparse
|
13 |
+
from itertools import islice
|
14 |
+
from PIL import Image
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
|
17 |
+
from data_ingestion import ColorizeIterableDataset, create_dataloaders
|
18 |
+
from model import Generator, Discriminator, init_weights
|
19 |
+
|
20 |
+
EXPERIMENT_NAME = "Colorizer_Experiment"
|
21 |
+
|
22 |
+
def setup_mlflow():
|
23 |
+
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
|
24 |
+
if experiment is None:
|
25 |
+
experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
|
26 |
+
else:
|
27 |
+
experiment_id = experiment.experiment_id
|
28 |
+
return experiment_id
|
29 |
+
|
30 |
+
def lab_to_rgb(L, ab):
|
31 |
+
"""Convert L and ab channels to RGB image"""
|
32 |
+
L = (L + 1.) * 50.
|
33 |
+
ab = ab * 128.
|
34 |
+
Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
|
35 |
+
rgb_imgs = []
|
36 |
+
for img in Lab:
|
37 |
+
img_rgb = lab2rgb(img)
|
38 |
+
rgb_imgs.append(img_rgb)
|
39 |
+
return np.stack(rgb_imgs, axis=0)
|
40 |
+
|
41 |
+
def preprocess_image(image_path):
|
42 |
+
img = Image.open(image_path).convert('RGB')
|
43 |
+
img = img.resize((256, 256)) # Resize to a consistent size
|
44 |
+
img_lab = rgb2lab(img)
|
45 |
+
img_lab = (img_lab + [0, 128, 128]) / [100, 255, 255] # Normalize LAB values
|
46 |
+
return img_lab[:,:,0], img_lab[:,:,1:]
|
47 |
+
|
+def visualize_results(epoch, generator, train_loader, device):
+    generator.eval()
+    with torch.no_grad():
+        for inputs, real_AB in train_loader:
+            inputs, real_AB = inputs.to(device), real_AB.to(device)
+            fake_AB = generator(inputs)
+
+            fake_rgb = lab_to_rgb(inputs.cpu(), fake_AB.cpu())
+            real_rgb = lab_to_rgb(inputs.cpu(), real_AB.cpu())
+
+            # Concatenate along the width axis so real and fake appear side by side
+            # (concatenating along the channel axis would yield 6-channel images
+            # that make_grid/imshow cannot render)
+            img_grid = make_grid(torch.from_numpy(np.concatenate([real_rgb, fake_rgb], axis=2)).permute(0, 3, 1, 2), normalize=True, nrow=4)
+
+            plt.figure(figsize=(15, 15))
+            plt.imshow(img_grid.permute(1, 2, 0).cpu())
+            plt.axis('off')
+            plt.title(f'Epoch {epoch}')
+            plt.savefig(f'results/epoch_{epoch}.png')
+            mlflow.log_artifact(f'results/epoch_{epoch}.png')
+            plt.close()
+            break  # Only visualize one batch
+    generator.train()
+
+def save_checkpoint(state, filename="checkpoint.pth.tar"):
+    torch.save(state, filename)
+    mlflow.log_artifact(filename)
+
+def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD):
+    if os.path.isfile(filename):
+        print(f"Loading checkpoint '{filename}'")
+        checkpoint = torch.load(filename)
+        start_epoch = checkpoint['epoch'] + 1
+        generator.load_state_dict(checkpoint['generator_state_dict'])
+        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
+        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
+        print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
+        return start_epoch
+    else:
+        print(f"No checkpoint found at '{filename}'")
+        return 0
+
+def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5):
+    criterion = nn.BCEWithLogitsLoss()
+    l1_loss = nn.L1Loss()
+
+    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
+    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
+
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    os.makedirs("results", exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
+    start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD)
+
+    experiment_id = setup_mlflow()
+    with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
+        try:
+            for epoch in range(start_epoch, num_epochs):
+                generator.train()
+                discriminator.train()
+
+                # Use a fixed number of iterations per epoch
+                num_iterations = 1000
+                pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}")
+
+                for i, (real_L, real_AB) in pbar:
+                    real_L, real_AB = real_L.to(device), real_AB.to(device)
+                    batch_size = real_L.size(0)
+
+                    # Train Discriminator
+                    optimizerD.zero_grad()
+
+                    fake_AB = generator(real_L)
+                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
+                    real_LAB = torch.cat([real_L, real_AB], dim=1)
+
+                    pred_fake = discriminator(fake_LAB.detach())
+                    loss_D_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
+
+                    pred_real = discriminator(real_LAB)
+                    loss_D_real = criterion(pred_real, torch.ones_like(pred_real))
+
+                    loss_D = (loss_D_fake + loss_D_real) * 0.5
+                    loss_D.backward()
+                    optimizerD.step()
+
+                    # Train Generator
+                    optimizerG.zero_grad()
+
+                    fake_AB = generator(real_L)
+                    fake_LAB = torch.cat([real_L, fake_AB], dim=1)
+                    pred_fake = discriminator(fake_LAB)
+
+                    loss_G_GAN = criterion(pred_fake, torch.ones_like(pred_fake))
+                    loss_G_L1 = l1_loss(fake_AB, real_AB) * 100  # L1 loss weight
+
+                    loss_G = loss_G_GAN + loss_G_L1
+                    loss_G.backward()
+                    optimizerG.step()
+
+                    pbar.set_postfix({
+                        'D_loss': loss_D.item(),
+                        'G_loss': loss_G.item(),
+                        'G_L1': loss_G_L1.item()
+                    })
+
+                    mlflow.log_metrics({
+                        "D_loss": loss_D.item(),
+                        "G_loss": loss_G.item(),
+                        "G_L1_loss": loss_G_L1.item()
+                    }, step=epoch * num_iterations + i)
+
+                visualize_results(epoch, generator, train_loader, device)
+
+                checkpoint = {
+                    'epoch': epoch,
+                    'generator_state_dict': generator.state_dict(),
+                    'discriminator_state_dict': discriminator.state_dict(),
+                    'optimizerG_state_dict': optimizerG.state_dict(),
+                    'optimizerD_state_dict': optimizerD.state_dict(),
+                }
+                save_checkpoint(checkpoint, filename=checkpoint_path)
+
+            print("Training completed successfully.")
+
+            # Log the generator model
+            mlflow.pytorch.log_model(generator, "generator_model")
+
+            # Register the model
+            model_uri = f"runs:/{run.info.run_id}/generator_model"
+            mlflow.register_model(model_uri, "colorizer_generator")
+
+            return run.info.run_id
+
+        except Exception as e:
+            print(f"Error during training: {str(e)}")
+            mlflow.log_param("error", str(e))
+            return None
+
+def test_training(generator, discriminator, train_loader, device):
+    print("Testing training process...")
+    try:
+        # train() swallows exceptions and returns None on failure,
+        # so check its return value rather than relying on an exception
+        run_id = train(generator, discriminator, train_loader, num_epochs=1, device=device)
+        if run_id is None:
+            print("Training process test failed: train() returned no run ID.")
+            return False
+        print("Training process test passed.")
+        return True
+    except Exception as e:
+        print(f"Training process test failed: {str(e)}")
+        return False
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train Colorizer model")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
+                        help="Device to use for training (cuda/cpu)")
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
+    parser.add_argument("--num_epochs", type=int, default=50, help="Number of epochs to train")
+    parser.add_argument("--test", action="store_true", help="Run in test mode")
+    args = parser.parse_args()
+
+    device = torch.device(args.device)
+    print(f"Using device: {device}")
+
+    try:
+        train_loader = create_dataloaders(batch_size=args.batch_size)
+
+        generator = Generator().to(device)
+        discriminator = Discriminator().to(device)
+
+        generator.apply(init_weights)
+        discriminator.apply(init_weights)
+
+        if args.test:
+            if test_training(generator, discriminator, train_loader, device):
+                print("All tests passed.")
+            else:
+                print("Tests failed.")
+        else:
+            run_id = train(generator, discriminator, train_loader, num_epochs=args.num_epochs, device=device)
+            if run_id:
+                print(f"Training completed. Run ID: {run_id}")
+                # Save the run ID to a file for easy access by the inference script
+                with open("latest_run_id.txt", "w") as f:
+                    f.write(run_id)
+            else:
+                print("Training failed.")
+
+    except Exception as e:
+        print(f"Critical error in main execution: {str(e)}")
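Two LAB normalizations appear in train.py: `preprocess_image` maps ab into [0, 1] by dividing by 255, while `lab_to_rgb` assumes L and ab were scaled into [-1, 1]. Below is a minimal sketch (not part of the commit) of the contract `lab_to_rgb` expects, assuming the dataloader follows the [-1, 1] convention; the random stand-in image is illustrative only.

import numpy as np
import torch
from skimage.color import rgb2lab

from train import lab_to_rgb  # the function defined in the diff above

rgb = np.random.rand(256, 256, 3)   # stand-in image with values in [0, 1]
lab = rgb2lab(rgb)                  # L in [0, 100], ab roughly in [-128, 127]

# Normalize the way lab_to_rgb expects: L and ab both in [-1, 1],
# with batch/channel dims of shape (1, 1, H, W) and (1, 2, H, W)
L = torch.from_numpy(lab[:, :, 0]).float()[None, None] / 50. - 1.
ab = torch.from_numpy(lab[:, :, 1:]).float().permute(2, 0, 1)[None] / 128.

recovered = lab_to_rgb(L, ab)       # shape (1, 256, 256, 3), values in [0, 1]
print(np.abs(recovered[0] - rgb).max())  # reconstruction error should be small

For reference, the argparse block implies the script is run as, e.g., `python train.py --test` for a one-epoch smoke test, or `python train.py --batch_size 16 --num_epochs 100` for a full run.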