import os
import subprocess

import torch
import gradio as gr
from clip_interrogator import Config, Interrogator

# Directory the precomputed files below are downloaded into.
# NOTE(review): assumes this matches clip_interrogator's Config cache path
# (left at its default here) — confirm if the library version changes.
CACHE_DIR = 'cache'

# Precomputed .pkl files for the ViT-H-14 (laion2b_s32b_b79k) model, one per
# label list (artists, flavors, mediums, movements, trendings).
CACHE_URLS = [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]


def _download_cache_file(url, cache_dir=CACHE_DIR):
    """Download *url* into *cache_dir* unless a file of that name already exists.

    Best-effort: a failed download is reported and skipped so startup
    continues, as it did when wget's exit status was ignored.

    The file is written to a ``.part`` path first and only renamed into place
    after wget succeeds. This means an interrupted download never leaves a
    truncated file under the final name. It also means re-running the script
    never produces wget's ``name.pkl.1``-style duplicates.
    """
    dest = os.path.join(cache_dir, url.rsplit('/', 1)[-1])
    if os.path.exists(dest):
        return
    tmp = dest + '.part'
    try:
        result = subprocess.run(
            ['wget', '-q', '-O', tmp, url],
            capture_output=True,
            text=True,
        )
    except OSError as err:  # e.g. wget is not installed
        print(f'warning: could not run wget for {url}: {err}')
        return
    if result.returncode == 0:
        os.replace(tmp, dest)  # atomic rename on the same filesystem
    else:
        print(f'warning: wget failed ({result.returncode}) for {url}: {result.stderr.strip()}')
        if os.path.exists(tmp):
            os.remove(tmp)


os.makedirs(CACHE_DIR, exist_ok=True)
for cache_url in CACHE_URLS:
    _download_cache_file(cache_url)

use_cuda = torch.cuda.is_available()

config = Config()
config.device = 'cuda' if use_cuda else 'cpu'
# BLIP offload is enabled only when running without CUDA.
config.blip_offload = not use_cuda
# Tuning values; their exact meaning is defined by clip_interrogator.Config.
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64

# A single Interrogator instance is shared by every request handled below.
ci = Interrogator(config)


def inference(image, mode, best_max_flavors):
    """Describe *image* as a prompt using the shared Interrogator.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or None
            when the user submits without uploading anything.
        mode: ``'best'`` or ``'classic'``. Any other value (the UI offers
            ``'fast'``) uses fast interrogation.
        best_max_flavors: Slider value, cast to ``int``. It is only used in
            ``'best'`` mode.

    Returns:
        Whatever the selected Interrogator method returns; it is shown in the
        output Textbox.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message in the
            UI instead of a generic error.
    """
    if image is None:
        # Without this guard, image.convert below raises AttributeError.
        raise gr.Error('Please upload an image first.')
    image = image.convert('RGB')
    if mode == 'best':
        return ci.interrogate(image, max_flavors=int(best_max_flavors))
    if mode == 'classic':
        return ci.interrogate_classic(image)
    return ci.interrogate_fast(image)


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# CLIP Interrogator")
        input_image = gr.Image(type='pil', elem_id="input-img")
        with gr.Row():
            mode_input = gr.Radio(['best', 'classic', 'fast'], label='Select mode', value='best')
            flavor_input = gr.Slider(minimum=2, maximum=48, step=2, value=32, label='best mode max flavors')

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Description Output")

    submit_btn.click(
        fn=inference,
        inputs=[input_image, mode_input, flavor_input],
        outputs=[output_text],
        concurrency_limit=10,
    )

demo.queue().launch()