import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel


class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Encode conditioning inputs. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    """Encoder whose ``encode`` returns its input unchanged."""

    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    """Embed class labels, with optional label dropout for unconditional guidance.

    Index ``n_classes - 1`` of the embedding table is used as the
    "unconditional" class: ``forward`` substitutes it for each label with
    probability ``ucg_rate``, and ``get_unconditional_conditioning`` emits it.
    For K real classes labelled ``0 .. K-1``, construct with
    ``n_classes=K + 1``; with the default of 1000 the unconditional index
    (999) coincides with a real label if labels span ``0 .. 999``.

    Args:
        embed_dim: size of each embedding vector.
        n_classes: number of rows in the embedding table (the last row is
            the unconditional class).
        key: default batch-dict key holding the labels.
        ucg_rate: per-label probability of replacement by the unconditional
            class in ``forward``; ``0`` disables the replacement.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        """Look up embeddings for the labels stored in ``batch[key]``.

        Args:
            batch: mapping whose ``key`` entry is a tensor of class labels
                (presumably shape ``(B,)`` — TODO confirm with callers).
            key: batch key to read; defaults to ``self.key``.
            disable_dropout: if True, never substitute the unconditional
                class, regardless of ``ucg_rate``.

        Returns:
            ``self.embedding`` applied to the (possibly dropped-out) labels
            after an extra axis is inserted at dim 1; for ``(B,)`` labels
            this is ``(B, 1, embed_dim)``.
        """
        if key is None:
            key = self.key
        # Insert a length-1 axis at dim 1; this is for use in crossattn.
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # Draw an independent Bernoulli(ucg_rate) per label. float() keeps
            # the probability tensor floating even if ucg_rate is an int.
            drop = torch.bernoulli(torch.ones_like(c) * float(self.ucg_rate)).bool()
            # torch.where selects label values exactly instead of blending
            # them through float arithmetic (mask * c + (1 - mask) * uc).
            c = torch.where(drop, torch.full_like(c, self.n_classes - 1), c)
        c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Build a batch dict holding only the unconditional class label.

        Args:
            bs: batch size.
            device: device on which to allocate the label tensor.

        Returns:
            ``{self.key: labels}`` where ``labels`` has shape ``(bs,)`` and
            every entry equals ``n_classes - 1``. The tensor has the default
            floating dtype (from ``torch.ones``); ``forward`` casts it to long.
        """
        # Same index forward() substitutes during label dropout: the last row
        # of the embedding table (see the class docstring for how to size
        # n_classes so this does not collide with a real label).
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        return {self.key: uc}


class DanbooruEmbedder(AbstractEncoder):
    def __init__(self):
        super().__init__()