# Spaces status at capture time: Runtime error
import io
import json
import multiprocessing
import os
import uuid
from multiprocessing.pool import ThreadPool

import gradio as gr
import numpy as np
import redis
import scipy
import scipy.io.wavfile
import scipy.signal
import torch
from datasets import load_dataset
from diffusers import (
    FluxPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler,
    StableDiffusionImg2ImgPipeline, DiffusionPipeline
)
from diffusers.utils import export_to_video
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    pipeline, AutoTokenizer, AutoModelForCausalLM, AutoProcessor,
    MusicgenForConditionalGeneration, WhisperProcessor, WhisperForConditionalGeneration,
    MarianMTModel, MarianTokenizer, BartTokenizer, BartForConditionalGeneration
)
# Pull variables from a local .env file into the environment before any
# configuration value below is read.
load_dotenv()

# Shared Redis connection used by every helper in this module.
# - The keyword for the credential is `password`; `redis_password` is not a
#   parameter of redis.Redis and makes the constructor raise TypeError.
# - Environment variables are strings (or None when unset), so the port is
#   converted explicitly and both host and port get the usual Redis defaults.
redis_client = redis.Redis(
    host=os.getenv('REDIS_HOST', 'localhost'),
    port=int(os.getenv('REDIS_PORT', '6379')),
    password=os.getenv("REDIS_PASSWORD"),
)

# Hugging Face access token read from the environment (None when HF_TOKEN is unset).
huggingface_token = os.getenv('HF_TOKEN')
def generate_unique_id():
    """Return a freshly generated random UUID (version 4) as a string."""
    new_id = uuid.uuid4()
    return str(new_id)
def store_special_tokens(tokenizer, model_name):
    """Save a tokenizer's special tokens and their ids in a Redis hash.

    The hash lives under ``tokenizer_special_tokens:<model_name>`` and has one
    field per attribute (``pad_token``, ``pad_token_id``, ``eos_token``, ...).

    Args:
        tokenizer: Object exposing the ``*_token`` / ``*_token_id`` attributes
            read below.
        model_name: Name used to build the Redis key.
    """
    attribute_names = (
        'pad_token', 'pad_token_id',
        'eos_token', 'eos_token_id',
        'unk_token', 'unk_token_id',
        'bos_token', 'bos_token_id',
    )
    # Tokenizers frequently leave some special tokens undefined (None), and
    # None cannot be written as a Redis hash value, so only defined ones are kept.
    special_tokens = {}
    for name in attribute_names:
        value = getattr(tokenizer, name)
        if value is not None:
            special_tokens[name] = value

    key = f"tokenizer_special_tokens:{model_name}"
    # Drop any previous snapshot so fields that are now undefined do not linger.
    redis_client.delete(key)
    if special_tokens:
        # hset(mapping=...) replaces the deprecated hmset().
        redis_client.hset(key, mapping=special_tokens)
def load_special_tokens(tokenizer, model_name):
    """Restore special tokens saved under ``tokenizer_special_tokens:<model_name>``.

    Only the fields actually present in the Redis hash are applied; attributes
    with no stored value are left untouched on ``tokenizer``. When the hash is
    missing or empty the tokenizer is not modified at all.

    Args:
        tokenizer: Object whose ``*_token`` / ``*_token_id`` attributes are set.
        model_name: Name used to build the Redis key.
    """
    stored = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
    if not stored:
        return

    def _as_text(value):
        # Without decode_responses the client hands back bytes for both
        # field names and values.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    special_tokens = {_as_text(k): _as_text(v) for k, v in stored.items()}

    for kind in ("pad", "eos", "unk", "bos"):
        token = special_tokens.get(f"{kind}_token")
        if token is not None:
            setattr(tokenizer, f"{kind}_token", token)
        token_id = special_tokens.get(f"{kind}_token_id")
        if token_id is not None:
            # Previously a missing id was replaced by -1 and written to the
            # tokenizer; now a missing id is simply skipped.
            setattr(tokenizer, f"{kind}_token_id", int(token_id))
def train_and_store_transformers_model(model_name, data):
    """Load a causal language model and store its weights and tokenizer info.

    Side effects:
        * special tokens are saved through ``store_special_tokens``;
        * the serialized ``state_dict`` is written to the Redis key
          ``transformers_model:<model_name>:state_dict``;
        * the tokenizer files are saved to the local ``transformers_tokenizer``
          directory and the list of written file paths is stored as JSON under
          ``transformers_tokenizer:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no optimisation
            step is performed here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # train() only switches the module to training mode; it does not train.
    model.train()
    store_special_tokens(tokenizer, model_name)

    # Serialize in memory instead of through a fixed file name shared by
    # every call.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"transformers_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns the paths of the files it wrote, which is not
    # a value Redis can store directly - encode it as JSON text.
    saved_files = tokenizer.save_pretrained("transformers_tokenizer")
    redis_client.set(f"transformers_tokenizer:{model_name}", json.dumps(list(saved_files)))
def generate_transformers_response_from_redis(model_name, prompt):
    """Generate text for ``prompt`` using weights previously stored in Redis.

    Args:
        model_name: Hub id or local path of the model; also part of the Redis
            key ``transformers_model:<model_name>:state_dict``.
        prompt: Input text.

    Returns:
        The decoded generation (at most 50 tokens including the prompt). The
        text is also stored under ``transformers_response:<uuid>``.

    Raises:
        LookupError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"transformers_model:{model_name}:state_dict")
    if model_data is None:
        raise LookupError(
            f"No stored weights for model {model_name!r}; "
            "run train_and_store_transformers_model first."
        )

    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Deserialize straight from memory; no shared temp file is needed.
    model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    # Prefer the locally saved tokenizer; fall back to the model's own
    # tokenizer when that directory is not present on this machine.
    tokenizer_source = "transformers_tokenizer" if os.path.isdir("transformers_tokenizer") else model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # Pass the whole encoding so the attention mask reaches generate().
        outputs = model.generate(**inputs, max_length=50)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    redis_client.set(f"transformers_response:{unique_id}", response)
    return response
def train_and_store_diffusers_model(model_name, data):
    """Save a Flux pipeline locally and record its location in Redis.

    The pipeline is written to the ``diffusers_model`` directory and the
    absolute path of that directory is stored under
    ``diffusers_model:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``FluxPipeline.from_pretrained``.
        data: Training data. NOTE(review): currently unused - no training
            is performed here.
    """
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    # Diffusers pipelines expose no train() method, and save_pretrained()
    # writes a directory of component sub-folders rather than a single
    # "flux_pipeline.pt" file, so the previous read of that file could never
    # succeed. CPU offloading is also unnecessary just to save the pipeline.
    pipe.save_pretrained("diffusers_model")
    # Record where the pipeline lives instead of pushing multi-gigabyte
    # weights through a single Redis string value.
    redis_client.set(f"diffusers_model:{model_name}", os.path.abspath("diffusers_model"))
def generate_diffusers_image_from_redis(model_name, prompt):
    """Generate an image with the Flux pipeline registered for ``model_name``.

    The Redis key ``diffusers_model:<model_name>`` is read as the directory of
    a saved pipeline; when the key is absent the local ``diffusers_model``
    directory is tried instead.

    Args:
        model_name: Name used to build the Redis key.
        prompt: Text prompt for the pipeline.

    Returns:
        The generated image. It is also saved as
        ``images/diffusers_<uuid>.png`` and that path is stored under
        ``diffusers_image:<uuid>``.

    Raises:
        FileNotFoundError: If the pipeline directory does not exist.
    """
    unique_id = generate_unique_id()
    stored = redis_client.get(f"diffusers_model:{model_name}")
    if stored:
        model_dir = stored.decode("utf-8") if isinstance(stored, bytes) else stored
    else:
        model_dir = "diffusers_model"
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(
            f"No saved pipeline found at {model_dir!r} for model {model_name!r}."
        )

    pipe = FluxPipeline.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    # Fixed seed so repeated calls with the same prompt are reproducible.
    generator = torch.Generator("cpu").manual_seed(0)
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=generator,
    ).images[0]

    # The output directory is not guaranteed to exist.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/diffusers_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"diffusers_image:{unique_id}", image_path)
    return image
def train_and_store_musicgen_model(model_name, data):
    """Load a MusicGen model and store its weights and processor info.

    Side effects:
        * the serialized ``state_dict`` is written to the Redis key
          ``musicgen_model:<model_name>:state_dict``;
        * the processor files are saved to the local ``musicgen_processor``
          directory and the value returned by ``save_pretrained`` is stored as
          JSON under ``musicgen_processor:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no optimisation
            step is performed here.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    # train() only switches the module to training mode; it does not train.
    model.train()

    # Serialize in memory instead of through a fixed, shared file name.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"musicgen_model:{model_name}:state_dict", buffer.getvalue())

    # The return value of save_pretrained() is not bytes/str, so it cannot be
    # handed to Redis directly - store it as JSON text.
    saved_files = processor.save_pretrained("musicgen_processor")
    redis_client.set(f"musicgen_processor:{model_name}", json.dumps(list(saved_files or [])))
def generate_musicgen_audio_from_redis(model_name, text_prompts):
    """Generate audio with MusicGen using weights previously stored in Redis.

    Args:
        model_name: Hub id or local path of the model; also part of the Redis
            key ``musicgen_model:<model_name>:state_dict``.
        text_prompts: Prompt text (or list of prompts) given to the processor.

    Returns:
        Path of the written WAV file, ``audio/musicgen_<uuid>.wav``. The path
        is also stored under ``musicgen_audio:<uuid>``. Only the first
        generated sample is written.

    Raises:
        LookupError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"musicgen_model:{model_name}:state_dict")
    if model_data is None:
        raise LookupError(
            f"No stored weights for model {model_name!r}; "
            "run train_and_store_musicgen_model first."
        )

    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    # Prefer the locally saved processor; fall back to the model's own.
    processor_source = "musicgen_processor" if os.path.isdir("musicgen_processor") else model_name
    processor = AutoProcessor.from_pretrained(processor_source)

    inputs = processor(text=text_prompts, padding=True, return_tensors="pt")
    with torch.no_grad():
        audio_values = model.generate(**inputs, max_new_tokens=256)

    # generate() returns a tensor, not a mapping: the sampling rate comes from
    # the audio encoder config and the waveform is indexed out of the tensor.
    # NOTE(review): [0, 0] assumes a (batch, channels, samples) layout and
    # keeps the first channel of the first prompt - confirm for stereo models.
    sampling_rate = model.config.audio_encoder.sampling_rate
    waveform = audio_values[0, 0].cpu().numpy()

    # The output directory is not guaranteed to exist.
    os.makedirs("audio", exist_ok=True)
    audio_path = f"audio/musicgen_{unique_id}.wav"
    scipy.io.wavfile.write(audio_path, rate=sampling_rate, data=waveform)
    redis_client.set(f"musicgen_audio:{unique_id}", audio_path)
    return audio_path
def train_and_store_stable_diffusion_model(model_name, data):
    """Save a Stable Diffusion pipeline locally and record its location in Redis.

    The pipeline (with its scheduler swapped for DPM-Solver multistep) is
    written to the ``stable_diffusion_model`` directory and the absolute path
    of that directory is stored under ``stable_diffusion_model:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no training
            is performed here.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    # Pipelines expose no train() method, and save_pretrained() produces a
    # directory rather than a "stable_diffusion_pipeline.pt" file, so the old
    # read of that file always failed. Moving to CUDA is not needed to save.
    pipe.save_pretrained("stable_diffusion_model")
    redis_client.set(
        f"stable_diffusion_model:{model_name}",
        os.path.abspath("stable_diffusion_model"),
    )
def generate_stable_diffusion_image_from_redis(model_name, prompt):
    """Generate an image with the Stable Diffusion pipeline registered for ``model_name``.

    The Redis key ``stable_diffusion_model:<model_name>`` is read as the
    directory of a saved pipeline; when the key is absent the local
    ``stable_diffusion_model`` directory is tried instead.

    Args:
        model_name: Name used to build the Redis key.
        prompt: Text prompt for the pipeline.

    Returns:
        The generated image. It is also saved as
        ``images/stable_diffusion_<uuid>.png`` and that path is stored under
        ``stable_diffusion_image:<uuid>``.

    Raises:
        FileNotFoundError: If the pipeline directory does not exist.
    """
    unique_id = generate_unique_id()
    stored = redis_client.get(f"stable_diffusion_model:{model_name}")
    if stored:
        model_dir = stored.decode("utf-8") if isinstance(stored, bytes) else stored
    else:
        model_dir = "stable_diffusion_model"
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(
            f"No saved pipeline found at {model_dir!r} for model {model_name!r}."
        )

    # Use the GPU with half precision when one is available; otherwise run on
    # the CPU in float32 instead of failing on the unconditional "cuda" move.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    image = pipe(prompt).images[0]

    # The output directory is not guaranteed to exist.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/stable_diffusion_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"stable_diffusion_image:{unique_id}", image_path)
    return image
def train_and_store_img2img_model(model_name, data):
    """Save an image-to-image pipeline locally and record its location in Redis.

    The pipeline is written to the ``img2img_model`` directory and the
    absolute path of that directory is stored under ``img2img_model:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no training
            is performed here.
    """
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    # Pipelines expose no train() method, and save_pretrained() produces a
    # directory rather than an "img2img_pipeline.pt" file, so the old read of
    # that file always failed. Moving to CUDA is not needed to save.
    pipe.save_pretrained("img2img_model")
    redis_client.set(f"img2img_model:{model_name}", os.path.abspath("img2img_model"))
def generate_img2img_from_redis(model_name, init_image, prompt, strength=0.75):
    """Transform ``init_image`` according to ``prompt`` with an img2img pipeline.

    The Redis key ``img2img_model:<model_name>`` is read as the directory of a
    saved pipeline; when the key is absent the local ``img2img_model``
    directory is tried instead.

    Args:
        model_name: Name used to build the Redis key.
        init_image: Path or file object accepted by ``PIL.Image.open``.
        prompt: Text prompt guiding the transformation.
        strength: Value forwarded to the pipeline's ``strength`` argument.

    Returns:
        The generated image. It is also saved as ``images/img2img_<uuid>.png``
        and that path is stored under ``img2img_image:<uuid>``.

    Raises:
        FileNotFoundError: If the pipeline directory does not exist.
    """
    unique_id = generate_unique_id()
    stored = redis_client.get(f"img2img_model:{model_name}")
    if stored:
        model_dir = stored.decode("utf-8") if isinstance(stored, bytes) else stored
    else:
        model_dir = "img2img_model"
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(
            f"No saved pipeline found at {model_dir!r} for model {model_name!r}."
        )

    # Use the GPU with half precision when one is available; otherwise run on
    # the CPU in float32 instead of failing on the unconditional "cuda" move.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_dir, torch_dtype=dtype)
    pipe = pipe.to(device)

    source_image = Image.open(init_image).convert("RGB")
    # The pipeline's keyword for the starting picture is `image`
    # (`init_image` is the old, removed spelling).
    image = pipe(prompt=prompt, image=source_image, strength=strength).images[0]

    # The output directory is not guaranteed to exist.
    os.makedirs("images", exist_ok=True)
    image_path = f"images/img2img_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"img2img_image:{unique_id}", image_path)
    return image
def train_and_store_marianmt_model(model_name, data):
    """Load a MarianMT model and store its weights and tokenizer info.

    Side effects:
        * the serialized ``state_dict`` is written to the Redis key
          ``marianmt_model:<model_name>:state_dict``;
        * the tokenizer files are saved to the local ``marianmt_tokenizer``
          directory and the list of written file paths is stored as JSON under
          ``marianmt_tokenizer:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no optimisation
            step is performed here.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    # train() only switches the module to training mode; it does not train.
    model.train()

    # Serialize in memory instead of through a fixed, shared file name.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"marianmt_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns the written file paths, which Redis cannot
    # store directly - encode them as JSON text.
    saved_files = tokenizer.save_pretrained("marianmt_tokenizer")
    redis_client.set(f"marianmt_tokenizer:{model_name}", json.dumps(list(saved_files)))
def translate_text_from_redis(model_name, text, src_lang, tgt_lang):
    """Translate ``text`` with a MarianMT model whose weights are stored in Redis.

    Args:
        model_name: Hub id or local path of the model; also part of the Redis
            key ``marianmt_model:<model_name>:state_dict``.
        text: Text to translate.
        src_lang: Source language code. Kept for interface compatibility; the
            source language is determined by the chosen checkpoint.
        tgt_lang: Target language code. Used only when the tokenizer lists
            ``>>tgt_lang<<`` among its supported language codes.

    Returns:
        The translated text, also stored under ``marianmt_translation:<uuid>``.

    Raises:
        LookupError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"marianmt_model:{model_name}:state_dict")
    if model_data is None:
        raise LookupError(
            f"No stored weights for model {model_name!r}; "
            "run train_and_store_marianmt_model first."
        )

    model = MarianMTModel.from_pretrained(model_name)
    model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    # Prefer the locally saved tokenizer; fall back to the model's own.
    tokenizer_source = "marianmt_tokenizer" if os.path.isdir("marianmt_tokenizer") else model_name
    tokenizer = MarianTokenizer.from_pretrained(tokenizer_source)

    # The tokenizer call takes no src_lang/tgt_lang arguments. Multilingual
    # Marian checkpoints select the target language through a ">>code<<"
    # prefix on the input text instead.
    target_prefix = f">>{tgt_lang}<<"
    if target_prefix in (getattr(tokenizer, "supported_language_codes", None) or ()):
        text = f"{target_prefix} {text}"

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        translated_tokens = model.generate(**inputs)
    translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    redis_client.set(f"marianmt_translation:{unique_id}", translation)
    return translation
def train_and_store_bart_model(model_name, data):
    """Load a BART model and store its weights and tokenizer info.

    Side effects:
        * the serialized ``state_dict`` is written to the Redis key
          ``bart_model:<model_name>:state_dict``;
        * the tokenizer files are saved to the local ``bart_tokenizer``
          directory and the list of written file paths is stored as JSON under
          ``bart_tokenizer:<model_name>``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        data: Training data. NOTE(review): currently unused - no optimisation
            step is performed here.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    # train() only switches the module to training mode; it does not train.
    model.train()

    # Serialize in memory instead of through a fixed, shared file name.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"bart_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns the written file paths, which Redis cannot
    # store directly - encode them as JSON text.
    saved_files = tokenizer.save_pretrained("bart_tokenizer")
    redis_client.set(f"bart_tokenizer:{model_name}", json.dumps(list(saved_files)))
def summarize_text_from_redis(model_name, text):
    """Summarize ``text`` with a BART model whose weights are stored in Redis.

    Args:
        model_name: Hub id or local path of the model; also part of the Redis
            key ``bart_model:<model_name>:state_dict``.
        text: Text to summarize (truncated by the tokenizer if too long).

    Returns:
        The summary (40 to 150 tokens, beam search with 4 beams), also stored
        under ``bart_summary:<uuid>``.

    Raises:
        LookupError: If no weights are stored for ``model_name``.
    """
    unique_id = generate_unique_id()
    model_data = redis_client.get(f"bart_model:{model_name}:state_dict")
    if model_data is None:
        raise LookupError(
            f"No stored weights for model {model_name!r}; "
            "run train_and_store_bart_model first."
        )

    model = BartForConditionalGeneration.from_pretrained(model_name)
    # Deserialize straight from memory; no shared temp file is needed.
    model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    # Prefer the locally saved tokenizer; fall back to the model's own.
    tokenizer_source = "bart_tokenizer" if os.path.isdir("bart_tokenizer") else model_name
    tokenizer = BartTokenizer.from_pretrained(tokenizer_source)
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    redis_client.set(f"bart_summary:{unique_id}", summary)
    return summary
def auto_train_and_store(model_name, task, data):
    """Run the train-and-store routine that corresponds to ``task``.

    Recognised tasks: "text-generation", "diffusers", "musicgen",
    "stable-diffusion", "img2img", "translation" and "summarization".
    Any other value is ignored and nothing happens.
    """
    handlers = {
        "text-generation": train_and_store_transformers_model,
        "diffusers": train_and_store_diffusers_model,
        "musicgen": train_and_store_musicgen_model,
        "stable-diffusion": train_and_store_stable_diffusion_model,
        "img2img": train_and_store_img2img_model,
        "translation": train_and_store_marianmt_model,
        "summarization": train_and_store_bart_model,
    }
    handler = handlers.get(task)
    if handler is not None:
        handler(model_name, data)
def transcribe_audio_from_redis(audio_file):
    """Transcribe a WAV recording with the ``openai/whisper-small`` model.

    Previously the supplied audio was written to disk and then ignored: a
    fixed sample from a test dataset was transcribed instead. The supplied
    recording is now the one that is decoded.

    Args:
        audio_file: WAV content as ``bytes``, a filesystem path, or a file
            object. Objects exposing a ``name`` attribute (such as temp-file
            wrappers) are opened through that name.

    Returns:
        The transcription of the recording as a string.
    """
    if isinstance(audio_file, (bytes, bytearray)):
        source = io.BytesIO(audio_file)
    elif isinstance(audio_file, (str, os.PathLike)):
        source = os.fspath(audio_file)
    else:
        source = getattr(audio_file, "name", audio_file)

    sampling_rate, samples = scipy.io.wavfile.read(source)
    samples = np.asarray(samples)

    # Bring integer PCM into floating point in roughly [-1.0, 1.0).
    if samples.dtype == np.uint8:
        # 8-bit WAV data is unsigned and centred on 128.
        samples = (samples.astype(np.float32) - 128.0) / 128.0
    elif np.issubdtype(samples.dtype, np.integer):
        samples = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max + 1)
    else:
        samples = samples.astype(np.float32)

    # Down-mix multi-channel audio (samples x channels) to mono.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # NOTE(review): 16 kHz is the rate the Whisper feature extractor is
    # expected to require - confirm against the processor's configuration.
    target_rate = 16000
    if sampling_rate != target_rate:
        samples = scipy.signal.resample_poly(samples, target_rate, sampling_rate)
    samples = np.asarray(samples, dtype=np.float32)

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    model.config.forced_decoder_ids = None

    input_features = processor(
        samples, sampling_rate=target_rate, return_tensors="pt"
    ).input_features
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]
def generate_image_from_redis(model_name, prompt, model_type):
    """Generate an image with the backend selected by ``model_type``.

    Args:
        model_name: Name forwarded to the selected generator.
        prompt: Text prompt.
        model_type: One of "diffusers", "stable-diffusion" or "img2img".
            The img2img backend always starts from ``init_image.png``.

    Returns:
        The image produced by the selected generator.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values
            (previously this surfaced as an UnboundLocalError).
    """
    if model_type == "diffusers":
        return generate_diffusers_image_from_redis(model_name, prompt)
    if model_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis(model_name, prompt)
    if model_type == "img2img":
        return generate_img2img_from_redis(model_name, "init_image.png", prompt)
    raise ValueError(
        f"Unsupported model_type {model_type!r}; "
        "expected 'diffusers', 'stable-diffusion' or 'img2img'."
    )
def generate_video_from_redis(prompt):
    """Generate a short video for ``prompt`` with ``damo-vilab/text-to-video-ms-1.7b``.

    Args:
        prompt: Text prompt for the text-to-video pipeline.

    Returns:
        Path of the video file written by ``export_to_video``. The path is
        also stored in Redis under ``video_<uuid>``.
    """
    pipe = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    video_frames = pipe(prompt, num_inference_steps=25).frames

    # NOTE(review): newer diffusers releases appear to return frames batched
    # per prompt (one sequence of frames for each prompt) - confirm against
    # the installed version. A single frame is a 3-D array, so a 4-D first
    # element (or a nested list) is treated as "batch of videos" and unwrapped.
    first = video_frames[0]
    if isinstance(first, (list, tuple)) or getattr(first, "ndim", 3) == 4:
        video_frames = first

    video_path = export_to_video(video_frames)
    unique_id = generate_unique_id()
    redis_client.set(f"video_{unique_id}", video_path)
    return video_path
def generate_random_response(prompts, generator):
    """Return the generated text for every prompt, in input order.

    ``generator`` is called as ``generator(prompt, max_length=50)`` and the
    ``'generated_text'`` entry of its first result is kept.
    """
    return [
        generator(text, max_length=50)[0]['generated_text']
        for text in prompts
    ]
def process_parallel(tasks):
    """Run zero-argument callables concurrently and return their results in order.

    The previous implementation sent a lambda to ``multiprocessing.Pool.map``;
    lambdas cannot be pickled for transfer to worker processes, so the call
    failed before any task ran. A thread pool needs no pickling and accepts
    any callable (lambdas, closures, bound methods).

    Args:
        tasks: Iterable of callables taking no arguments.

    Returns:
        List with each task's return value, in the order the tasks were given.
        An exception raised by a task propagates to the caller.
    """
    tasks = list(tasks)
    if not tasks:
        return []

    def _run(task):
        return task()

    worker_count = min(len(tasks), os.cpu_count() or 1)
    with ThreadPool(processes=worker_count) as pool:
        return pool.map(_run, tasks)
def generate_response_from_prompt(prompt, generator):
    """Return the generated text for a single prompt.

    ``generator`` is called as ``generator(prompt, max_length=50)`` and the
    ``'generated_text'`` entry of its first result is returned.
    """
    candidates = generator(prompt, max_length=50)
    return candidates[0]['generated_text']
def generate_image_from_prompt(prompt, image_type):
    """Generate an image for ``prompt`` with the backend named by ``image_type``.

    Each backend is called with a fixed model name ("diffusers_model_name",
    "stable_diffusion_model_name" or "img2img_model_name"); the img2img
    backend always starts from ``init_image.png``.

    Args:
        prompt: Text prompt.
        image_type: One of "diffusers", "stable-diffusion" or "img2img".

    Returns:
        The image produced by the selected generator.

    Raises:
        ValueError: If ``image_type`` is not one of the supported values
            (previously this surfaced as an UnboundLocalError).
    """
    if image_type == "diffusers":
        return generate_diffusers_image_from_redis("diffusers_model_name", prompt)
    if image_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis("stable_diffusion_model_name", prompt)
    if image_type == "img2img":
        return generate_img2img_from_redis("img2img_model_name", "init_image.png", prompt)
    raise ValueError(
        f"Unsupported image_type {image_type!r}; "
        "expected 'diffusers', 'stable-diffusion' or 'img2img'."
    )
def gradio_app(
    text_model="gpt2",
    musicgen_model="facebook/musicgen-small",
    translation_model="Helsinki-NLP/opus-mt-en-es",
    summarization_model="facebook/bart-large-cnn",
):
    """Build the Gradio demo UI and launch it.

    The UI has one section each for text generation, image generation, video
    generation, MusicGen audio, Whisper transcription, translation and
    summarization.

    Several handlers take more arguments than the UI supplies (a generator
    or a model name), so each button is wired to a small adapter that adds the
    missing argument. The audio, translation and summarization handlers read
    their weights from Redis, so the chosen model must have been stored there
    beforehand.

    Args:
        text_model: Model used to build the text-generation pipeline.
        musicgen_model: Model name passed to the MusicGen handler.
        translation_model: Model name passed to the translation handler.
        summarization_model: Model name passed to the summarization handler.
    """
    generator_cache = {}

    def _generate_text(prompt):
        # Build the text-generation pipeline on first use, not at startup.
        if "generator" not in generator_cache:
            generator_cache["generator"] = pipeline("text-generation", model=text_model)
        return generate_response_from_prompt(prompt, generator_cache["generator"])

    def _generate_audio(text_prompt):
        return generate_musicgen_audio_from_redis(musicgen_model, text_prompt)

    def _transcribe(audio_path):
        # The audio component hands over a file path; the handler is given the
        # file's bytes.
        if not audio_path:
            return ""
        with open(audio_path, "rb") as audio_stream:
            return transcribe_audio_from_redis(audio_stream.read())

    def _translate(text, src_lang, tgt_lang):
        return translate_text_from_redis(translation_model, text, src_lang, tgt_lang)

    def _summarize(text):
        return summarize_text_from_redis(summarization_model, text)

    with gr.Blocks() as app:
        gr.Markdown("## Generación de Texto con Transformers")
        with gr.Row():
            prompt_text = gr.Textbox(label="Texto de Entrada")
            text_output = gr.Textbox(label="Respuesta")
        text_button = gr.Button("Generar Texto")
        text_button.click(_generate_text, inputs=prompt_text, outputs=text_output)

        gr.Markdown("## Generación de Imágenes con Diffusers, Stable Diffusion e Img2Img")
        with gr.Row():
            prompt_image = gr.Textbox(label="Prompt de Imagen")
            image_type = gr.Dropdown(["diffusers", "stable-diffusion", "img2img"], label="Tipo de Imagen")
            image_output = gr.Image(type="pil", label="Imagen Generada")
        image_button = gr.Button("Generar Imagen")
        image_button.click(generate_image_from_prompt, inputs=[prompt_image, image_type], outputs=image_output)

        gr.Markdown("## Generación de Video")
        with gr.Row():
            prompt_video = gr.Textbox(label="Prompt de Video")
            # gr.Video takes no `type` argument.
            video_output = gr.Video(label="Video Generado")
        video_button = gr.Button("Generar Video")
        video_button.click(generate_video_from_redis, inputs=prompt_video, outputs=video_output)

        gr.Markdown("## Generación de Audio con MusicGen")
        with gr.Row():
            text_prompts_audio = gr.Textbox(label="Prompts de Audio")
            # "filepath" is the audio type that exchanges file paths.
            audio_output = gr.Audio(type="filepath", label="Audio Generado")
        audio_button = gr.Button("Generar Audio")
        audio_button.click(_generate_audio, inputs=text_prompts_audio, outputs=audio_output)

        gr.Markdown("## Transcripción de Audio con Whisper")
        with gr.Row():
            audio_file = gr.Audio(type="filepath", label="Archivo de Audio")
            transcription_output = gr.Textbox(label="Transcripción")
        transcribe_button = gr.Button("Transcribir Audio")
        transcribe_button.click(_transcribe, inputs=audio_file, outputs=transcription_output)

        gr.Markdown("## Traducción de Texto")
        with gr.Row():
            text_input = gr.Textbox(label="Texto a Traducir")
            translation_output = gr.Textbox(label="Traducción")
            src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
            tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
        translate_button = gr.Button("Traducir Texto")
        translate_button.click(
            _translate,
            inputs=[text_input, src_lang_input, tgt_lang_input],
            outputs=translation_output,
        )

        gr.Markdown("## Resumen de Texto")
        with gr.Row():
            text_to_summarize = gr.Textbox(label="Texto para Resumir")
            summary_output = gr.Textbox(label="Resumen")
        summarize_button = gr.Button("Generar Resumen")
        summarize_button.click(_summarize, inputs=text_to_summarize, outputs=summary_output)

    app.launch()


# Start the UI only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app()