# Hugging Face Space: ismot/hel1 (Space status when captured: "Runtime error")
# File: utils.py
# Last commit by ismot: "Duplicate from huggan/FastGan" (40e6a61)

import torch
import torch.nn as nn
from enum import Enum
import base64
import json
from io import BytesIO
from PIL import Image
import requests
import re
from copy import deepcopy
class ImageType(Enum):
    """Labels for the image variants handled by this module.

    The four ``REAL_*`` members name the quadrants of an image (numbered
    clockwise starting from the upper-left one); ``FAKE`` is the remaining
    label and does not correspond to a quadrant.
    """

    # Upper half of the image.
    REAL_UP_L = 0
    REAL_UP_R = 1
    # Lower half of the image (note: right comes before left here).
    REAL_DOWN_R = 2
    REAL_DOWN_L = 3
    # Not a quadrant; presumably marks generated images -- confirm with callers.
    FAKE = 4
def crop_image_part(image: torch.Tensor,
                    part: ImageType) -> torch.Tensor:
    """Return one quadrant of ``image``.

    Axes 2 and 3 of ``image`` are each split at their own midpoint (integer
    division, so for an odd length the second half is one element longer)
    and the requested combination of halves is returned. Axes 0 and 1 are
    kept whole.

    Args:
        image: Tensor indexed as ``image[:, :, rows, cols]``; presumably a
            (batch, channels, height, width) image batch -- confirm with
            callers.
        part: Which quadrant to return; one of the four ``REAL_*`` members
            of ``ImageType``.

    Returns:
        The selected quadrant, obtained by basic slicing of ``image``.

    Raises:
        ValueError: If ``part`` is not one of the four quadrant members
            (for example ``ImageType.FAKE``).
    """
    # Halve each spatial axis independently so that non-square inputs are
    # still cut into four equal quadrants; for square inputs both values
    # are the same.
    half_rows = image.shape[2] // 2
    half_cols = image.shape[3] // 2

    if part == ImageType.REAL_UP_L:
        return image[:, :, :half_rows, :half_cols]
    elif part == ImageType.REAL_UP_R:
        return image[:, :, :half_rows, half_cols:]
    elif part == ImageType.REAL_DOWN_L:
        return image[:, :, half_rows:, :half_cols]
    elif part == ImageType.REAL_DOWN_R:
        return image[:, :, half_rows:, half_cols:]
    else:
        raise ValueError('invalid part')
def init_weights(module: nn.Module) -> None:
    """Re-initialise the parameters of a single layer in place.

    * ``nn.Conv2d``: weight drawn from a normal distribution with mean 0.0
      and standard deviation 0.02 (the bias is left untouched).
    * ``nn.BatchNorm2d``: weight drawn from a normal distribution with mean
      1.0 and standard deviation 0.02, bias set to zero.

    Any other module type is left unchanged. The single-module signature
    fits ``nn.Module.apply``, which calls a function on every submodule.

    Args:
        module: The layer to initialise.
    """
    if isinstance(module, nn.Conv2d):
        torch.nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) registers weight and bias as None;
        # skip them instead of crashing.
        if module.weight is not None:
            torch.nn.init.normal_(module.weight, 1.0, 0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
def load_image_from_local(image_path, image_resize=None):
    """Open the image at ``image_path`` with PIL, optionally resizing it.

    Args:
        image_path: Passed straight to ``Image.open``.
        image_resize: Target size handed to ``Image.resize``. It is only
            honoured when it is a ``tuple``; any other value (including
            ``None`` or a list) leaves the image at its original size.

    Returns:
        The opened image, resized if a tuple size was given.
    """
    opened = Image.open(image_path)
    if not isinstance(image_resize, tuple):
        return opened
    return opened.resize(image_resize)
def load_image_from_url(image_url, rgba_mode=False, image_resize=None,
                        default_image=None, timeout=10.0):
    """Download an image and return it as a PIL image, best effort.

    Args:
        image_url: URL to fetch with ``requests.get``.
        rgba_mode: When true, convert the downloaded image to ``"RGBA"``.
        image_resize: Target size handed to ``Image.resize``; only honoured
            when it is a ``tuple``.
        default_image: Local path of a fallback image. When the download or
            decoding fails and this is truthy, the fallback is loaded with
            ``load_image_from_local`` (resized with ``image_resize``, but
            not converted to RGBA).
        timeout: Seconds passed to ``requests.get`` so that an unresponsive
            server cannot block the caller indefinitely.

    Returns:
        The downloaded image, the fallback image, or ``None`` when the
        download failed and no fallback was given.
    """
    try:
        # The context manager releases the connection even on failure.
        with requests.get(image_url, timeout=timeout) as response:
            # Treat HTTP error statuses as failures instead of trying to
            # decode an error page as an image.
            response.raise_for_status()
            # Use the decoded body: ``response.raw`` does not undo
            # gzip/deflate content encoding.
            payload = response.content

        image = Image.open(BytesIO(payload))
        if rgba_mode:
            image = image.convert("RGBA")
        if isinstance(image_resize, tuple):
            image = image.resize(image_resize)
    except Exception:
        # Deliberately broad: any network, HTTP or decoding problem falls
        # back to the default image (or None) rather than propagating.
        image = None
        if default_image:
            image = load_image_from_local(default_image, image_resize=image_resize)
    return image
def image_to_base64(image_array):
    """Encode an image as a PNG ``data:`` URI string.

    Args:
        image_array: Object exposing ``save(file_obj, format=...)``;
            presumably a PIL image despite the name -- confirm with callers.

    Returns:
        ``"data:image/png;base64, "`` (with a space after the comma)
        followed by the base64 text of the PNG bytes.
    """
    png_buffer = BytesIO()
    image_array.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64, " + encoded.decode("utf-8")
def copy_G_params(model):
    """Take an independent snapshot of a model's parameter values.

    Args:
        model: Object exposing ``parameters()``; each parameter's ``.data``
            tensor is copied.

    Returns:
        A list of deep-copied tensors, in ``model.parameters()`` order.
        Later changes to the model do not affect the returned tensors.
    """
    current_values = [param.data for param in model.parameters()]
    return deepcopy(current_values)
def load_params(model, new_param):
    """Overwrite a model's parameter values in place.

    Parameters are paired with ``new_param`` positionally via ``zip``, so
    pairing stops at the shorter of the two sequences.

    Args:
        model: Object exposing ``parameters()``.
        new_param: Iterable of tensors whose values are copied into the
            corresponding parameter's ``.data``.
    """
    targets = model.parameters()
    for target, source in zip(targets, new_param):
        target.data.copy_(source)