sgerard's picture
Initial commit of working version
# Code adapted from the following sources:
import torch
from PIL import Image
from models import Generator
def load_img_generator(model_name_or_path):
    """Load a pretrained ``Generator`` and put it in inference (eval) mode.

    Args:
        model_name_or_path: Identifier forwarded unchanged to
            ``Generator.from_pretrained`` (callers in this file pass a Hub-style
            repo name such as ``"huggan/fastgan-few-shot-..."``).

    Returns:
        The loaded ``Generator`` with ``eval()`` applied. Device placement is
        whatever ``from_pretrained`` chooses; callers must move it if needed.
    """
    # ``from_pretrained`` constructs and returns a new model (the original code
    # rebinds its result and re-passes the constructor kwargs), so call it on
    # the class rather than first building a randomly-initialised throwaway
    # instance just to reach the method.
    generator = Generator.from_pretrained(
        model_name_or_path, in_channels=256, out_channels=3
    )
    # Switch to inference mode before use; ``eval()`` mutates in place.
    generator.eval()
    return generator
def _denormalize(input: torch.Tensor) -> torch.Tensor:
return (input * 127.5) + 127.5
def generate_img(device, gan_model):
    """Sample a single image from a pretrained FastGAN few-shot generator.

    Args:
        device: torch device (or device string) on which to run generation.
        gan_model: Suffix of the model name; the generator is loaded from
            ``"huggan/fastgan-few-shot-" + gan_model``.

    Returns:
        A ``PIL.Image.Image`` built from the generated uint8 array.
    """
    img_generator = load_img_generator("huggan/fastgan-few-shot-" + gan_model)
    # The noise below lives on ``device``; the model must live there too,
    # otherwise the forward pass fails with a device mismatch (e.g. CUDA input
    # against CPU weights, which is where ``from_pretrained`` loads by default
    # — NOTE(review): confirm for this model class).
    img_generator = img_generator.to(device)
    noise = torch.zeros(1, 256, 1, 1, device=device).normal_(0.0, 1.0)
    with torch.no_grad():
        # The generator returns a pair; only the first element is used.
        gan_images, _ = img_generator(noise)
    # Take the single batch element (exact, unlike squeeze() which would also
    # drop any size-1 spatial dim), rescale to pixel range, and clamp so the
    # uint8 cast cannot wrap out-of-range values.
    gan_image = _denormalize(gan_images[0]).clamp(0, 255)
    # Move dim 0 (presumably the out_channels=3 channel axis) last, as PIL
    # expects H x W x C, then bring to CPU as uint8.
    gan_image = gan_image.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(gan_image)