Update src/pipelines/pipeline_echo_mimic.py
src/pipelines/pipeline_echo_mimic.py
CHANGED
@@ -34,6 +34,7 @@ from transformers import CLIPImageProcessor
 from src.models.mutual_self_attention import ReferenceAttentionControl
 from src.pipelines.context import get_context_scheduler
 from src.pipelines.utils import get_tensor_interpolation_method
+from src.utils.step_func import origin_by_velocity_and_sample, psuedo_velocity_wrt_noisy_and_timestep, get_alpha
 
 @dataclass
 class Audio2VideoPipelineOutput(BaseOutput):
@@ -417,9 +418,9 @@ class Audio2VideoPipeline(DiffusionPipeline):
             generator
         )
         # print(video_length, latents.shape)
-        …
-        uc_face_locator_tensor = torch.zeros_like(…
-        face_locator_tensor = torch.cat([uc_face_locator_tensor, …
+        c_face_locator_tensor = self.face_locator(face_mask_tensor)
+        uc_face_locator_tensor = torch.zeros_like(c_face_locator_tensor)
+        face_locator_tensor = torch.cat([uc_face_locator_tensor, c_face_locator_tensor], dim=0)
         # Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
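This hunk computes the face-locator features once, builds an all-zeros copy as the unconditional branch, and concatenates the two along the batch dimension so a single UNet pass covers both classifier-free-guidance branches. A minimal sketch of that zero-conditioning pattern, assuming only that the conditioning is a dense feature tensor (the shapes below are illustrative, not taken from the pipeline):

# Zero-conditioning pattern for classifier-free guidance (sketch).
import torch

def build_cfg_condition(cond_feat: torch.Tensor) -> torch.Tensor:
    """Stack an all-zeros unconditional branch on top of the conditional
    features so one forward pass evaluates both CFG branches."""
    uncond_feat = torch.zeros_like(cond_feat)           # "no face mask" branch
    return torch.cat([uncond_feat, cond_feat], dim=0)   # batch dimension doubles

# Hypothetical shape: (batch, channels, frames, height, width)
cond = torch.randn(1, 320, 16, 64, 64)
both = build_cfg_condition(cond)
assert both.shape[0] == 2 * cond.shape[0]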
@@ -474,7 +475,7 @@ class Audio2VideoPipeline(DiffusionPipeline):
                         encoder_hidden_states=None,
                         return_dict=False,
                     )
-                    reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=…
+                    reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=do_classifier_free_guidance)
 
 
                 num_context_batches = math.ceil(len(context_queue) / context_batch_size)
@@ -498,8 +499,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
                         .to(device)
                         .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
                     )
-                    …
-                    audio_latents = torch.cat([torch.zeros_like(…
+                    c_audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
+                    audio_latents = torch.cat([torch.zeros_like(c_audio_latents), c_audio_latents], 0)
 
                     latent_model_input = self.scheduler.scale_model_input(
                         latent_model_input, t
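The audio conditioning follows the same doubling scheme, but is rebuilt for every context batch: the features for each frame window in new_context are gathered and concatenated, then a zeroed copy is prepended for the unconditional branch. A small sketch, assuming audio_fea_final is laid out as (batch, total_frames, ...) and new_context is a list of frame-index windows (shapes are illustrative):

# Per-window audio conditioning for one context batch (sketch).
import torch

audio_fea_final = torch.randn(1, 64, 50, 384)       # hypothetical audio features
new_context = [[0, 1, 2, 3], [4, 5, 6, 7]]          # two frame windows
device = "cpu"

# Gather the audio features for every window in the current context batch.
c_audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device)
# Prepend a zeroed copy as the unconditional branch for classifier-free guidance.
audio_latents = torch.cat([torch.zeros_like(c_audio_latents), c_audio_latents], 0)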
@@ -508,11 +509,15 @@ class Audio2VideoPipeline(DiffusionPipeline):
                         latent_model_input,
                         t,
                         encoder_hidden_states=None,
-                        audio_cond_fea=audio_latents,
-                        face_musk_fea=face_locator_tensor,
+                        audio_cond_fea=audio_latents if do_classifier_free_guidance else c_audio_latents,
+                        face_musk_fea=face_locator_tensor if do_classifier_free_guidance else c_face_locator_tensor,
                         return_dict=False,
                     )[0]
 
+                    alphas_cumprod = self.scheduler.alphas_cumprod.to(latent_model_input.device)
+                    x_pred = origin_by_velocity_and_sample(pred, latent_model_input, alphas_cumprod, t)
+                    pred = psuedo_velocity_wrt_noisy_and_timestep(latent_model_input, x_pred, alphas_cumprod, t, torch.ones_like(t) * (-1))
+
                     for j, c in enumerate(new_context):
                         noise_pred[:, :, c] = noise_pred[:, :, c] + pred
                         counter[:, :, c] = counter[:, :, c] + 1
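The three added lines route the UNet output through helpers imported from src.utils.step_func, which is not part of this diff. The sketch below shows the standard v-prediction identities such helpers usually implement (clean sample from velocity and noisy sample, then a velocity recomputed from the noisy sample and the predicted clean sample); treat it as an assumption about their behaviour, not their actual implementation, and note that the final timestep argument (torch.ones_like(t) * (-1) in the diff) is not modelled here:

# Assumed v-prediction identities (sketch, not the real src.utils.step_func).
import torch

def origin_by_velocity_and_sample(v, x_t, alphas_cumprod, t):
    # x0 = sqrt(a_t) * x_t - sqrt(1 - a_t) * v   (v-prediction parameterization)
    a_t = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    return a_t.sqrt() * x_t - (1.0 - a_t).sqrt() * v

def psuedo_velocity_wrt_noisy_and_timestep(x_t, x_0, alphas_cumprod, t, t_next):
    # Recover the noise implied by (x_t, x_0) at timestep t, then re-express it
    # as a velocity target; the t_next argument is ignored in this sketch.
    a_t = alphas_cumprod[t].view(-1, 1, 1, 1, 1)
    eps = (x_t - a_t.sqrt() * x_0) / (1.0 - a_t).sqrt()
    return a_t.sqrt() * eps - (1.0 - a_t).sqrt() * x_0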
@@ -523,6 +528,8 @@ class Audio2VideoPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (
                         noise_pred_text - noise_pred_uncond
                     )
+                else:
+                    noise_pred = noise_pred / counter
 
                 latents = self.scheduler.step(
                     noise_pred, t, latents, **extra_step_kwargs
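The new else branch matters because noise_pred is an accumulator: each overlapping context window adds its prediction into noise_pred and increments counter for the frames it covers, so the sum has to be divided by counter to obtain a per-frame average. With guidance enabled that normalization presumably happens on the classifier-free-guidance path (not shown in this hunk); without guidance the division was previously skipped. A toy illustration with made-up shapes and windows:

# Averaging overlapping context-window predictions (toy example).
import torch

frames = 8
noise_pred = torch.zeros(1, 4, frames, 8, 8)     # accumulated predictions
counter = torch.zeros(1, 1, frames, 1, 1)        # how many windows hit each frame

windows = [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]   # overlapping contexts
for c in windows:
    pred = torch.randn(1, 4, len(c), 8, 8)       # stand-in for the UNet output
    noise_pred[:, :, c] += pred
    counter[:, :, c] += 1

noise_pred = noise_pred / counter                # frames covered twice get averaged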
@@ -583,4 +590,4 @@ class Audio2VideoPipeline(DiffusionPipeline):
         smoothed_tensor = torch.cat(
             [tensor[:, :, 0:1, :, :], internal_frames, tensor[:, :, -1:, :, :]], dim=2)
 
-        return smoothed_tensor
+        return smoothed_tensor