import torch
from PIL import Image
# from strhub.data.module import SceneTextDataModule
from torchvision import transforms as T

import gradio as gr

# Load the pretrained PARSeq model and put it in eval mode.
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()

# Equivalent of SceneTextDataModule.get_transform(parseq.hparams.img_size),
# inlined here to avoid the strhub dependency:
# img_transform = SceneTextDataModule.get_transform(parseq.hparams.img_size)
transform = T.Compose([
    T.Resize(parseq.hparams.img_size, T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(0.5, 0.5)
])


def infer(inps):
    img = inps.convert('RGB')
    # Preprocess. The model expects a batch of images with shape: (B, C, H, W)
    img = transform(img).unsqueeze(0)

    with torch.inference_mode():
        logits = parseq(img)

    # Greedy decoding: softmax over the vocabulary, then detokenize.
    pred = logits.softmax(-1)
    label, confidence = parseq.tokenizer.decode(pred)
    return label[0]


demo = gr.Interface(fn=infer,
                    inputs=gr.Image(type="pil", label="Input Image"),
                    outputs=gr.Textbox(label="Output Text"))

demo.launch()
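
# A minimal sketch of exercising infer() directly, without launching the
# Gradio UI (assumes a local image file 'demo_image.png'; hypothetical path):
#
#   from PIL import Image
#   print(infer(Image.open('demo_image.png')))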