File size: 1,735 Bytes
81d8e7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
import torch


class Conv2d(nn.Conv2d):
    """Subclass of ``torch.nn.Conv2d`` whose forward pass is unchanged.

    It adds no behavior of its own: ``forward`` hands the input straight to
    ``nn.Conv2d.forward`` and returns that result.
    """

    def forward(self, x):
        # Plain delegation to the stock convolution; no extra processing.
        return super().forward(x)


class DepthGuider(ModelMixin):
    """Encode a conditioning map into a small convolutional embedding.

    The input (presumably a depth map, judging by the class name — confirm
    with callers) is first resized to ``conditioning_size`` with bilinear
    interpolation. It is then passed through ``conv_in`` and a stack of 3x3
    convolutions. For each consecutive pair of entries in
    ``block_out_channels``, the stack holds:

    * a stride-1 conv that keeps the channel count, then
    * a stride-2 conv that switches to the next channel count and halves the
      spatial size (rounding up).

    SiLU follows every conv except ``conv_out``. With the defaults (four
    channel entries, 512x512) the output is
    ``(batch, conditioning_embedding_channels, 64, 64)``.

    Args:
        conditioning_embedding_channels: Output channels of ``conv_out``.
        conditioning_channels: Input channels expected by ``conv_in``.
        block_out_channels: Channel widths of the stack. The first entry is
            the width after ``conv_in``. Must be non-empty; indexing
            ``[0]`` raises ``IndexError`` otherwise.
        conditioning_size: Spatial size ``(height, width)`` every input is
            resized to before encoding. Defaults to ``(512, 512)``, the value
            previously hard-coded in ``forward``.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int = 4,
        conditioning_channels: int = 1,
        block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
        conditioning_size: Tuple[int, int] = (512, 512),
    ):
        super().__init__()
        # Plain attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.conditioning_size = conditioning_size
        self.conv_in = Conv2d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )
        self.blocks = nn.ModuleList([])

        for channel_in, channel_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            # Refine features at the current resolution, keeping the width...
            self.blocks.append(
                Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            # ...then move to the next width and halve the resolution.
            self.blocks.append(
                Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
            )

        self.conv_out = Conv2d(
            block_out_channels[-1],
            conditioning_embedding_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, conditioning):
        """Encode ``conditioning`` into the embedding.

        Args:
            conditioning: 4-D tensor ``(batch, conditioning_channels, H, W)``.
                Bilinear interpolation requires 4-D input. Any ``H, W`` is
                accepted because the tensor is resized first.

        Returns:
            Tensor ``(batch, conditioning_embedding_channels, h, w)``, where
            ``h, w`` are ``conditioning_size`` halved (rounding up) once per
            stride-2 block.
        """
        conditioning = F.interpolate(
            conditioning,
            size=self.conditioning_size,
            mode="bilinear",
            align_corners=True,
        )
        embedding = F.silu(self.conv_in(conditioning))

        for block in self.blocks:
            embedding = F.silu(block(embedding))

        # The output projection is deliberately left without an activation.
        return self.conv_out(embedding)