"""Gradio demo that runs uploaded images through a VAE trained on Fashion MNIST.

Loads a pickled model from ``vae.pt`` and serves a single-image interface whose
output is a 28x28 image produced by the model.
"""

import numpy as np
import torch
import gradio as gr
from vae import *  # NOTE(review): presumably provides the model class needed to unpickle vae.pt — confirm
import matplotlib.image as mpimg

# torch.load unpickles arbitrary Python objects: only ever point it at a trusted file.
# map_location="cpu" keeps the model on the CPU even if the checkpoint was saved from a
# GPU; generate_image feeds it CPU tensors and converts the result with .numpy(), which
# would fail with the model on another device.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects a fully pickled
# module; if loading fails there, pass weights_only=False (trusted file only).
with open("vae.pt", "rb") as file:
    vae = torch.load(file, map_location="cpu")
vae.eval()


def generate_image(filename):
    """Run the image stored at *filename* through the VAE and return the result.

    Args:
        filename: Path to an image file readable by ``matplotlib.image.imread``.

    Returns:
        A 28x28 NumPy array built from the first element of the model's output
        (presumably the reconstruction — confirm against ``vae.py``).

    NOTE(review): the image is not resized here; the model presumably expects a
    28x28 input like its Fashion MNIST training data — confirm.
    """
    image = mpimg.imread(filename)
    if image.ndim == 3:
        # Colour (RGB/RGBA) input: keep only the first channel as the grayscale plane.
        # 2-D (already grayscale) input is used as-is instead of raising IndexError.
        image = image[..., 0]
    if np.issubdtype(image.dtype, np.integer):
        # imread returns integer pixels for non-PNG formats (e.g. the JPEG examples),
        # but PNGs already come back as floats in [0, 1] and must not be scaled again.
        image = image / np.iinfo(image.dtype).max
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = vae(torch.Tensor(image))[0]
    return output.reshape((28, 28)).detach().numpy()


examples = [f"examples/{i}.jpg" for i in range(10)]

demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Image(type="filepath"),
    outputs="image",
    examples=examples,
    title="VAE running on Fashion MNIST",
    description=".",
    article="...",
    # Gradio expects "never"/"auto"/"manual"; the boolean form is deprecated and
    # rejected by later releases.
    allow_flagging="never",
)
demo.launch()