hyoungwoncho committed
Commit: 2315bef
Parent(s): 22808d8

Update pipeline.py

pipeline.py CHANGED (+23 -48)
@@ -38,10 +38,8 @@ EXAMPLE_DOC_STRING = """
         ```py
         >>> import torch
         >>> from diffusers import StableDiffusionPipeline
-
         >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
         >>> pipe = pipe.to("cuda")
-
         >>> prompt = "a photo of an astronaut riding a horse on mars"
         >>> image = pipe(prompt).images[0]
         ```
@@ -64,8 +62,12 @@ class PAGIdentitySelfAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
         if attn.spatial_norm is not None:
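Note on the change above: per the deprecation message, a LoRA `scale` should now reach the attention modules through the pipeline call rather than the processor signature. A minimal, illustrative sketch (the `pipe` object, prompt, and scale value are assumptions, not part of this commit):

```py
>>> # Hypothetical usage after this change: the LoRA scale travels through
>>> # `cross_attention_kwargs` on the pipeline call instead of the processor's
>>> # removed `scale` parameter.
>>> image = pipe(
...     "a photo of an astronaut riding a horse on mars",
...     cross_attention_kwargs={"scale": 0.7},
... ).images[0]
```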
@@ -91,11 +93,9 @@ class PAGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        query = attn.to_q(hidden_states_org, *args)
-        key = attn.to_k(hidden_states_org, *args)
-        value = attn.to_v(hidden_states_org, *args)
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -115,7 +115,7 @@ class PAGIdentitySelfAttnProcessor:
         hidden_states_org = hidden_states_org.to(query.dtype)
 
         # linear proj
-        hidden_states_org = attn.to_out[0](hidden_states_org, *args)
+        hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)
 
@@ -134,9 +134,7 @@ class PAGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        value = attn.to_v(hidden_states_ptb, *args)
+        value = attn.to_v(hidden_states_ptb)
 
         hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
         #hidden_states_ptb = value
@@ -144,7 +142,7 @@ class PAGIdentitySelfAttnProcessor:
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
 
         # linear proj
-        hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
 
@@ -178,8 +176,12 @@ class PAGCFGIdentitySelfAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
         if attn.spatial_norm is not None:
@@ -205,12 +207,10 @@ class PAGCFGIdentitySelfAttnProcessor:
 
         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
-
-        args = () if USE_PEFT_BACKEND else (scale,)
 
-        query = attn.to_q(hidden_states_org, *args)
-        key = attn.to_k(hidden_states_org, *args)
-        value = attn.to_v(hidden_states_org, *args)
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -230,7 +230,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         hidden_states_org = hidden_states_org.to(query.dtype)
 
         # linear proj
-        hidden_states_org = attn.to_out[0](hidden_states_org, *args)
+        hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)
 
@@ -249,14 +249,12 @@ class PAGCFGIdentitySelfAttnProcessor:
         if attn.group_norm is not None:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
-        args = () if USE_PEFT_BACKEND else (scale,)
-
-        value = attn.to_v(hidden_states_ptb, *args)
+        value = attn.to_v(hidden_states_ptb)
         hidden_states_ptb = value
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
 
         # linear proj
-        hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
 
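For reference, the two perturbed branches edited above differ only in how they replace the self-attention output. A minimal sketch of that distinction (the tensor shape is an illustrative assumption):

```py
>>> import torch
>>> # `value` stands in for attn.to_v(hidden_states_ptb): (batch, seq_len, inner_dim)
>>> value = torch.randn(2, 64, 320)
>>> # PAGIdentitySelfAttnProcessor zeroes the perturbed attention output
>>> # (its commented-out `#hidden_states_ptb = value` is the identity variant):
>>> hidden_states_ptb = torch.zeros_like(value)
>>> # PAGCFGIdentitySelfAttnProcessor keeps the value projection directly,
>>> # i.e. the attention map is replaced by the identity:
>>> hidden_states_ptb = value
```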
@@ -298,7 +296,6 @@ def retrieve_timesteps(
     """
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
@@ -311,7 +308,6 @@ def retrieve_timesteps(
             Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
             timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
             must be `None`.
-
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
@@ -332,22 +328,19 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusionPAGPipeline(
+class StableDiffusionPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
-
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -540,7 +533,6 @@ class StableDiffusionPAGPipeline(
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -885,12 +877,9 @@ class StableDiffusionPAGPipeline(
 
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
-
         The suffixes after the scaling factors represent the stages where they are being applied.
-
         Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
         that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
-
         Args:
             s1 (`float`):
                 Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
@@ -914,13 +903,9 @@ class StableDiffusionPAGPipeline(
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
         key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
         <Tip warning={true}>
-
         This API is 🧪 experimental.
-
         </Tip>
-
         Args:
             unet (`bool`, defaults to `True`): To apply fusion on the UNet.
             vae (`bool`, defaults to `True`): To apply fusion on the VAE.
@@ -944,17 +929,12 @@ class StableDiffusionPAGPipeline(
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
     def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
         """Disable QKV projection fusion if enabled.
-
         <Tip warning={true}>
-
         This API is 🧪 experimental.
-
         </Tip>
-
         Args:
             unet (`bool`, defaults to `True`): To apply fusion on the UNet.
             vae (`bool`, defaults to `True`): To apply fusion on the VAE.
-
         """
         if unet:
             if not self.fusing_unet:
@@ -974,7 +954,6 @@ class StableDiffusionPAGPipeline(
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
         Args:
             timesteps (`torch.Tensor`):
                 generate embedding vectors at these timesteps
@@ -982,7 +961,6 @@ class StableDiffusionPAGPipeline(
                 dimension of the embeddings to generate
             dtype:
                 data type of the generated embeddings
-
         Returns:
             `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
         """
@@ -1128,7 +1106,6 @@ class StableDiffusionPAGPipeline(
     ):
         r"""
         The call function to the pipeline for generation.
-
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
@@ -1195,9 +1172,7 @@ class StableDiffusionPAGPipeline(
             The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
             will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
             `._callback_tensor_inputs` attribute of your pipeline class.
-
         Examples:
-
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,