File size: 5,512 Bytes
d52990b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0054c9
 
 
 
 
d52990b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b065204
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an additive skip connection.

    When the input and output channel counts differ, a 1x1 convolution
    projects the input so it can be added to the main path; otherwise the
    input is used as-is (identity skip).
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 projection only when the channel count changes; None means identity skip.
        if in_channels != out_channels:
            self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.downsample = None

    def forward(self, x):
        # Skip branch: project when channels differ, otherwise pass x through.
        skip = x if self.downsample is None else self.downsample(x)
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Residual add, then final activation.
        return self.relu(y + skip)

class EnhancedGarmentNet(nn.Module):
    """U-Net-style encoder/decoder built from ResidualBlocks.

    Three encoder stages (each followed by 2x max-pool), a bridge, and three
    decoder stages (each preceded by 2x bilinear upsampling and a channel
    concatenation with the matching encoder skip). Spatial size is preserved
    end-to-end for inputs whose H and W are divisible by 8.

    forward(x) returns:
        out: tensor with the same channel count as the input (``in_channels``).
        [e1, e2, e3, b]: the encoder/bridge feature pyramid with
        base, 2*base, 4*base, 8*base channels respectively.
    """

    def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4):
        super(EnhancedGarmentNet, self).__init__()

        # Stem: large receptive field, lifts input to base_channels.
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True)
        )

        self.encoder1 = self._make_layer(base_channels, base_channels, num_residual_blocks)
        self.encoder2 = self._make_layer(base_channels, base_channels*2, num_residual_blocks)
        self.encoder3 = self._make_layer(base_channels*2, base_channels*4, num_residual_blocks)

        self.bridge = self._make_layer(base_channels*4, base_channels*8, num_residual_blocks)

        # BUG FIX: each decoder consumes torch.cat([upsampled, skip], dim=1),
        # so its input width is (upsampled + skip) channels, not just the
        # upsampled branch. The previous widths (8x, 4x, 2x base) made the
        # first decoder conv fail at runtime with a channel mismatch.
        self.decoder3 = self._make_layer(base_channels*8 + base_channels*4, base_channels*4, num_residual_blocks)
        self.decoder2 = self._make_layer(base_channels*4 + base_channels*2, base_channels*2, num_residual_blocks)
        self.decoder1 = self._make_layer(base_channels*2 + base_channels, base_channels, num_residual_blocks)

        # Head: project back to the input channel count.
        self.final = nn.Conv2d(base_channels, in_channels, kernel_size=7, padding=3)

        self.downsample = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _make_layer(self, in_channels, out_channels, num_blocks):
        """Stack of ``num_blocks`` ResidualBlocks; only the first changes width."""
        layers = [ResidualBlock(in_channels, out_channels)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem
        x = self.initial(x)

        # Encoder: feature pyramid at full, 1/2 and 1/4 resolution.
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.downsample(e1))
        e3 = self.encoder3(self.downsample(e2))

        # Bridge at 1/8 resolution.
        b = self.bridge(self.downsample(e3))

        # Decoder: upsample, concatenate the encoder skip, then refine.
        d3 = self.decoder3(torch.cat([self.upsample(b), e3], dim=1))
        d2 = self.decoder2(torch.cat([self.upsample(d3), e2], dim=1))
        d1 = self.decoder1(torch.cat([self.upsample(d2), e1], dim=1))

        out = self.final(d1)

        # Return the reconstruction plus the pyramid for downstream conditioning.
        return out, [e1, e2, e3, b]

class EnhancedGarmentNetWithTimestep(nn.Module):
    """Wraps EnhancedGarmentNet and conditions its feature pyramid on a
    diffusion timestep and a text embedding.

    forward(x, t, text_embeds) returns the garment reconstruction and a list
    of per-level feature maps fused with the (time + text) conditioning
    vector via 1x1 convolutions.
    """

    def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4, time_emb_dim=256):
        super(EnhancedGarmentNetWithTimestep, self).__init__()

        self.garment_net = EnhancedGarmentNet(in_channels, base_channels, num_residual_blocks)

        # Lift the scalar timestep to a time_emb_dim-dimensional vector.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        # Project text embeddings into the same conditioning space.
        # Assumes 768-dim text embeddings (e.g. a CLIP/BERT encoder) — confirm with caller.
        self.text_proj = nn.Linear(768, time_emb_dim)

        # One 1x1 fusion conv per pyramid level: (feat + cond) channels -> feat channels.
        self.combine = nn.ModuleList(
            nn.Conv2d(base_channels * mult + time_emb_dim, base_channels * mult, kernel_size=1)
            for mult in (1, 2, 4, 8)
        )

    def forward(self, x, t, text_embeds):
        # Align every input's dtype with the model parameters (e.g. under mixed precision).
        param_dtype = self.garment_net.initial[0].weight.dtype
        x = x.to(dtype=param_dtype)
        t = t.to(dtype=param_dtype)
        text_embeds = text_embeds.to(dtype=param_dtype)

        # Backbone pass: reconstruction plus the [e1, e2, e3, b] pyramid.
        garment_out, garment_features = self.garment_net(x)

        # Timestep (B,) -> (B, 1) -> (B, emb) -> (B, emb, 1, 1).
        t_emb = self.time_mlp(t.unsqueeze(-1)).unsqueeze(-1).unsqueeze(-1)
        # Text (B, 768) -> (B, emb, 1, 1).
        text_emb = self.text_proj(text_embeds).unsqueeze(-1).unsqueeze(-1)
        # Single conditioning signal shared by all pyramid levels.
        cond_emb = t_emb + text_emb

        # Broadcast the conditioning over each feature map and fuse with a 1x1 conv.
        combined_features = [
            fuse(torch.cat([feat, cond_emb.expand(-1, -1, feat.size(2), feat.size(3))], dim=1))
            for feat, fuse in zip(garment_features, self.combine)
        ]

        return garment_out, combined_features