# Spaces: Running on Zero
from functools import partial
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from torch import nn

from models.celeb_embeddings import _get_celeb_embeddings_basis
from models.id_embedding.helpers import get_rep_pos, shift_tensor_dim0
from models.id_embedding.meta_net import StyleVectorizer
# Default placeholder string(s). Not referenced by the code visible in this
# file (EmbeddingManagerId_adain tokenizes the literal "*" itself) --
# presumably imported by other modules; verify before removing.
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
# Not referenced by the code visible in this file; meaning of the value is
# not established here -- TODO confirm against callers.
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
    """Tokenize ``string`` and return the tokenizer's ``input_ids``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(string, **options)`` whose
            result can be indexed with ``"input_ids"`` (presumably a Hugging
            Face CLIP tokenizer -- confirm against callers).
        string: The text (or batch of texts) handed to the tokenizer.

    Returns:
        The ``"input_ids"`` entry of the tokenizer output. Padding,
        truncation and ``return_tensors="pt"`` are requested from the
        tokenizer.
    """
    batch_encoding = tokenizer(
        string,
        return_length=True,
        padding=True,
        truncation=True,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    return batch_encoding["input_ids"]
def get_embedding_for_clip_token(embedder, token):
    """Embed ``token`` after adding a leading dimension of size 1.

    Args:
        embedder: Callable applied to the unsqueezed token tensor (the class
            in this file passes ``text_encoder.text_model.embeddings``).
        token: Tensor of token ids.

    Returns:
        Whatever ``embedder`` returns for ``token.unsqueeze(0)``; callers
        index ``[0]`` to drop the added leading dimension again.
    """
    return embedder(token.unsqueeze(0))
class EmbeddingManagerId_adain(nn.Module):
    """Project random vectors to word embeddings and splice them into prompts.

    A ``StyleVectorizer`` MLP maps an input vector to ``num_embeds_per_token``
    embeddings of width ``token_dim``. In ``adain_mode == 0`` the MLP output
    is rescaled with the mean/std returned by ``_get_celeb_embeddings_basis``.
    The result overwrites the embeddings at (and right after) every
    placeholder token ``"*"`` found in the tokenized prompt.
    """

    # experiment_name -> name list passed to _get_celeb_embeddings_basis.
    _CELEB_NAME_FILES = {
        "normal_GAN": "datasets_face/good_names.txt",
        "man_GAN": "datasets_face/good_names_man.txt",
        "woman_GAN": "datasets_face/good_names_woman.txt",
    }

    def __init__(
        self,
        tokenizer,
        text_encoder,
        device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        experiment_name="normal_GAN",
        num_embeds_per_token: int = 2,
        loss_type: Optional[str] = None,
        mlp_depth: int = 2,
        token_dim: int = 1024,
        input_dim: int = 1024,
        **kwargs
    ):
        """Build the projection MLP, the discriminator and the AdaIN statistics.

        Args:
            tokenizer: Tokenizer used to look up the placeholder and name tokens.
            text_encoder: Encoder exposing ``text_model.embeddings``.
            device: Device the celeb mean/std tensors are moved to; also used
                by ``load`` as ``map_location``.
            experiment_name: One of ``"normal_GAN"``, ``"man_GAN"``,
                ``"woman_GAN"``; selects the name list for the statistics.
            num_embeds_per_token: Number of embeddings produced per identity.
            loss_type: Accepted for interface compatibility; unused here.
            mlp_depth: ``depth`` passed to ``StyleVectorizer``.
            token_dim: Width of one word embedding.
            input_dim: Width of the vectors fed to the projection MLP.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``experiment_name`` is not a known experiment.
        """
        super().__init__()
        self.device = device
        self.num_es = num_embeds_per_token
        self.get_token_for_string = partial(get_clip_token_for_string, tokenizer)
        self.get_embedding_for_tkn = partial(
            get_embedding_for_clip_token, text_encoder.text_model.embeddings
        )
        self.token_dim = token_dim

        # 1. Placeholder token id: row 0, position 1 of the tokenized "*"
        # (position 0 is skipped -- presumably the BOS token; confirm).
        self.placeholder_token = self.get_token_for_string("*")[0][1]

        # 2. AdaIN statistics for the selected experiment.
        # Raise instead of `assert 0`: asserts are stripped under `python -O`.
        names_file = self._CELEB_NAME_FILES.get(experiment_name)
        if names_file is None:
            raise ValueError(
                f"Unknown experiment_name {experiment_name!r}; "
                f"expected one of {sorted(self._CELEB_NAME_FILES)}"
            )
        mean, std = _get_celeb_embeddings_basis(tokenizer, text_encoder, names_file)
        print("now experiment_name:", experiment_name)
        self.celeb_embeddings_mean = mean.to(device)
        self.celeb_embeddings_std = std.to(device)

        # 3. Trainable sub-networks.
        self.name_projection_layer = StyleVectorizer(
            input_dim, self.token_dim * self.num_es, depth=mlp_depth, lr_mul=0.1
        )
        self.embedding_discriminator = Embedding_discriminator(
            self.token_dim * self.num_es, dropout_rate=0.2
        )
        # 0 -> rescale MLP output with the celeb mean/std; anything else -> raw output.
        self.adain_mode = 0

    def forward(
        self,
        tokenized_text,
        embedded_text,
        name_batch,
        random_embeddings=None,
        timesteps=None,
    ):
        """Generate identity embeddings and write them into ``embedded_text``.

        Args:
            tokenized_text: Token ids of the prompts, or ``None`` to skip the
                placeholder replacement.
            embedded_text: Prompt embeddings; modified in place at the
                placeholder positions and returned.
            name_batch: ``None`` or a list of name strings to embed.
            random_embeddings: ``None`` or the input batch for the projection MLP.
            timesteps: Unused.

        Returns:
            ``(embedded_text, other_return_dict)``. The dict may contain
            ``"total_embedding"``, ``"adained_total_embedding"`` and
            ``"name_embeddings"`` depending on which inputs were given.

        Raises:
            TypeError: If ``name_batch`` is neither ``None`` nor a list.
            ValueError: If a placeholder is found but ``random_embeddings``
                was not supplied, so there is nothing to insert.
        """
        other_return_dict = {}
        adained_total_embedding = None

        if random_embeddings is not None:
            mlp_output_embedding = self.name_projection_layer(random_embeddings)
            # Use the configured sizes rather than hard-coded (2, 1024).
            total_embedding = mlp_output_embedding.view(
                mlp_output_embedding.shape[0], self.num_es, self.token_dim
            )
            if self.adain_mode == 0:
                adained_total_embedding = (
                    total_embedding * self.celeb_embeddings_std + self.celeb_embeddings_mean
                )
            else:
                adained_total_embedding = total_embedding
            other_return_dict["total_embedding"] = total_embedding
            other_return_dict["adained_total_embedding"] = adained_total_embedding

        if name_batch is not None:
            if not isinstance(name_batch, list):
                raise TypeError(
                    f"name_batch must be a list of strings, got {type(name_batch).__name__}"
                )
            # Keep token positions 1 and 2 of every name (position 0 skipped).
            name_tokens = self.get_token_for_string(name_batch)[:, 1:3]
            # random_embeddings may be None here; fall back to the manager's device.
            target_device = (
                random_embeddings.device if random_embeddings is not None else self.device
            )
            name_embeddings = self.get_embedding_for_tkn(name_tokens.to(target_device))[0]
            other_return_dict["name_embeddings"] = name_embeddings

        if tokenized_text is not None:
            # Rows of (batch_index, token_index) for every placeholder occurrence.
            placeholder_pos = np.array(get_rep_pos(tokenized_text, [self.placeholder_token]))
            if len(placeholder_pos) != 0:
                if adained_total_embedding is None:
                    raise ValueError(
                        "tokenized_text contains the placeholder token but "
                        "random_embeddings was not provided"
                    )
                end_index = min(adained_total_embedding.shape[0], placeholder_pos.shape[0])
                rows, cols = placeholder_pos[:, 0], placeholder_pos[:, 1]
                # Slot k of each identity goes k tokens after the placeholder.
                for k in range(self.num_es):
                    embedded_text[rows, cols + k] = adained_total_embedding[:end_index, k, :]

        return embedded_text, other_return_dict

    def load(self, ckpt_path):
        """Restore ``name_projection_layer`` from a checkpoint written by ``save``.

        The checkpoint holds a pickled module, not a state dict.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources. It is required because `save`
        # stores the whole module (the default on recent PyTorch rejects that).
        # map_location follows self.device so CPU-only hosts can load too.
        ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        projection = ckpt.get("name_projection_layer")
        if projection is not None:
            self.name_projection_layer = projection.float()
            print('[Embedding Manager] weights loaded.')
        else:
            print('[Embedding Manager] checkpoint has no "name_projection_layer"; weights unchanged.')

    def save(self, ckpt_path):
        """Write ``name_projection_layer`` (the module object) to ``ckpt_path``."""
        torch.save({"name_projection_layer": self.name_projection_layer}, ckpt_path)

    def trainable_projection_parameters(self):
        """Return the parameters of ``name_projection_layer`` as a list."""
        return list(self.name_projection_layer.parameters())
class Embedding_discriminator(nn.Module):
    """Three-layer MLP that scores a flattened embedding with a single logit.

    Layout: ``input_size -> 512 -> 256 -> 1`` with LeakyReLU(0.2) after the
    first two layers and optional dropout applied before each activation.
    """

    # Sub-modules always present / present only when dropout_rate > 0.
    _CORE_LAYERS = ("fc1", "fc2", "fc3", "LayerNorm1", "LayerNorm2", "leaky_relu")
    _DROPOUT_LAYERS = ("dropout1", "dropout2")

    def __init__(self, input_size, dropout_rate):
        """Create the layers.

        Args:
            input_size: Number of features per sample after flattening.
            dropout_rate: Dropout probability; dropout layers are created
                only when this is greater than 0.
        """
        super().__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        # NOTE: the LayerNorm modules are created and checkpointed but are not
        # applied in forward(); kept so existing checkpoints stay compatible.
        self.LayerNorm1 = nn.LayerNorm(512)
        self.LayerNorm2 = nn.LayerNorm(256)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout_rate = dropout_rate
        if self.dropout_rate > 0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, input):
        """Return one score per sample, shape ``(-1, 1)``.

        ``input`` is reshaped to ``(-1, input_size)`` before the first layer.
        """
        use_dropout = self.dropout_rate > 0

        x = self.fc1(input.view(-1, self.input_size))
        if use_dropout:
            x = self.dropout1(x)
        x = self.leaky_relu(x)

        x = self.fc2(x)
        if use_dropout:
            x = self.dropout2(x)
        x = self.leaky_relu(x)

        return self.fc3(x)

    def save(self, ckpt_path):
        """Write the sub-modules (as module objects) to ``ckpt_path``.

        Dropout layers are included only if they exist, so a discriminator
        built with ``dropout_rate == 0`` can be saved as well.
        """
        save_dict = {name: getattr(self, name) for name in self._CORE_LAYERS}
        for name in self._DROPOUT_LAYERS:
            if hasattr(self, name):
                save_dict[name] = getattr(self, name)
        torch.save(save_dict, ckpt_path)

    def load(self, ckpt_path):
        """Restore the sub-modules from a checkpoint written by ``save``.

        Entries missing from the checkpoint leave the current sub-module
        untouched.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources. It is required because `save`
        # stores module objects (the default on recent PyTorch rejects that).
        # Load onto the device this module currently lives on instead of
        # assuming CUDA is available.
        ckpt = torch.load(ckpt_path, map_location=self.fc1.weight.device, weights_only=False)
        # Gate on a key that save() actually writes; the previous check used
        # "first_name_proj_layer", which save() never writes.
        if ckpt.get("fc1") is None:
            print('[Embedding D] checkpoint has no "fc1"; weights unchanged.')
            return
        for name in self._CORE_LAYERS + self._DROPOUT_LAYERS:
            layer = ckpt.get(name)
            if layer is not None:
                setattr(self, name, layer.float())
        print('[Embedding D] weights loaded.')