"""Gradio demo for ClarityGAN: single-image dehazing with a CycleGAN generator."""

import numpy as np
import torch
from torchvision import transforms
import gradio as gr

from pyrolens_deployment.gradio_app.dehazing_gen import CycleGenerator

# Generator with 6 residual blocks; weights are mapped onto the CPU so the
# demo runs without a GPU.
# NOTE(review): torch.load unpickles the checkpoint. That is acceptable for
# this bundled, trusted file; if it could ever come from an untrusted source,
# pass weights_only=True (requires torch >= 1.13).
gan = CycleGenerator(num_residuals=6)
gan.load_state_dict(torch.load("genC.pth.tar", map_location=torch.device('cpu')))
# Switch to inference mode so any training-only layer behaviour (e.g. dropout,
# batch-norm statistic updates) is disabled while serving requests.
gan.eval()

# Preprocessing is identical for every request, so build the pipeline once.
# NOTE(review): inputs are only scaled to [0, 1] by ToTensor (no Normalize);
# confirm this matches the preprocessing used during training.
gan_transforms = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
])


def dehaze(img):
    """Run the dehazing generator on one image.

    Args:
        img: PIL image from the input component, or None when the component
            is empty.

    Returns:
        The generator output as an HWC NumPy float array at 800x800
        resolution, or None when no input image was supplied.
    """
    if img is None:
        # Button pressed before an image was uploaded: leave output empty.
        return None
    # Force 3-channel RGB so RGBA / palette PNGs don't change the channel count.
    tensor = gan_transforms(img.convert("RGB"))
    with torch.no_grad():
        # Add a batch dimension for the generator, then drop it again.
        dehazed_output = gan(tensor.unsqueeze(0)).squeeze(0)
    # CHW -> HWC, the layout image components expect.
    # NOTE(review): values are returned unclipped, exactly as the generator
    # produced them; confirm the generator's output range matches what the
    # output component expects.
    return dehazed_output.cpu().numpy().transpose(1, 2, 0)


sample_images = [
    ("Haze", "gradio_check1.png"),
    ("Haze", "gradio_check10.png"),
    ("Haze", "gradio_check13.png"),
]

with gr.Blocks() as demo:
    gr.Markdown("# ClarityGAN")
    gr.Markdown("## Image Dehazing using CycleGANs")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            with gr.Row():
                dehaze_button = gr.Button("Dehaze")
            # Clicking a sample loads that file into the input image; the
            # previous per-sample buttons ignored the file path entirely.
            gr.Examples(
                examples=[path for _name, path in sample_images],
                inputs=input_image,
            )
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")

    # The Dehaze button previously had no handler and did nothing.
    dehaze_button.click(dehaze, inputs=input_image, outputs=output_image)

demo.launch()