sohojoe committed on
Commit
7d22699
1 Parent(s): b6c665f

runs on mac but generation is wrong

Files changed (9)
  1. app.py +263 -0
  2. pipeline.py +486 -0
  3. pup1.jpg +0 -0
  4. pup2.jpg +0 -0
  5. pup3.jpg +0 -0
  6. pup4.jpeg +0 -0
  7. pup5.jpg +0 -0
  8. requirements.txt +8 -0
  9. test-platform.py +8 -0
app.py ADDED
@@ -0,0 +1,263 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from diffusers import StableDiffusionPipeline, StableDiffusionImageVariationPipeline, DiffusionPipeline
+ import numpy as np
+ import pandas as pd
+ import math
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ # model_id = "stabilityai/stable-diffusion-2-1-base"
+ # text_model_id = "CompVis/stable-diffusion-v-1-4-original"
+ # text_model_id = "CompVis/stable-diffusion-v1-4"
+ text_model_id = "runwayml/stable-diffusion-v1-5"
+ # text_model_id = "stabilityai/stable-diffusion-2-1-base"
+ model_id = "lambdalabs/sd-image-variations-diffusers"
+ clip_model_id = "openai/clip-vit-large-patch14-336"
+
+ max_tabs = 5
+ input_images = [None for i in range(max_tabs)]
+ input_prompts = [None for i in range(max_tabs)]
+ embedding_plots = [None for i in range(max_tabs)]
+ # global embedding_base64s
+ embedding_base64s = [None for i in range(max_tabs)]
+ # embedding_base64s = gr.State(value=[None for i in range(max_tabs)])
+
+
+ def image_to_embedding(input_im):
+     tform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Resize(
+             (224, 224),
+             interpolation=transforms.InterpolationMode.BICUBIC,
+             antialias=False,
+         ),
+         transforms.Normalize(
+             [0.48145466, 0.4578275, 0.40821073],
+             [0.26862954, 0.26130258, 0.27577711]),
+     ])
+
+     inp = tform(input_im).to(device)
+     dtype = next(pipe.image_encoder.parameters()).dtype
+     image = inp.tile(1, 1, 1, 1).to(device=device, dtype=dtype)
+     image_embeddings = pipe.image_encoder(image).image_embeds
+     image_embeddings = image_embeddings[0]
+     image_embeddings_np = image_embeddings.cpu().detach().numpy()
+     return image_embeddings_np
+
+ def prompt_to_embedding(prompt):
+     # inputs = processor(prompt, images=imgs, return_tensors="pt", padding=True)
+     inputs = processor(prompt, return_tensors="pt", padding='max_length', max_length=77)
+     # labels = torch.tensor(labels)
+     # prompt_tokens = inputs.input_ids[0]
+     prompt_tokens = inputs.input_ids
+     # image = inputs.pixel_values
+     with torch.no_grad():
+         prompt_embeddings = model.get_text_features(prompt_tokens.to(device))
+     prompt_embeddings = prompt_embeddings[0].cpu().detach().numpy()
+     return prompt_embeddings
+
+ def embedding_to_image(embeddings):
+     size = math.ceil(math.sqrt(embeddings.shape[0]))
+     image_embeddings_square = np.pad(embeddings, (0, size**2 - embeddings.shape[0]), 'constant')
+     image_embeddings_square.resize(size, size)
+     embedding_image = Image.fromarray(image_embeddings_square, mode="L")
+     return embedding_image
+
+ def embedding_to_base64(embeddings):
+     import base64
+     # ensure float16
+     embeddings = embeddings.astype(np.float16)
+     embeddings_b64 = base64.urlsafe_b64encode(embeddings).decode()
+     return embeddings_b64
+
+ def base64_to_embedding(embeddings_b64):
+     import base64
+     embeddings = base64.urlsafe_b64decode(embeddings_b64)
+     embeddings = np.frombuffer(embeddings, dtype=np.float16)
+     # embeddings = torch.tensor(embeddings)
+     return embeddings
+
+ def main(
+     # input_im,
+     embeddings,
+     scale=3.0,
+     n_samples=4,
+     steps=25,
+     seed=0
+ ):
+
+     if seed is None:
+         seed = np.random.randint(2147483647)
+     # generator = torch.Generator(device=device).manual_seed(int(seed))
+     generator = torch.Generator().manual_seed(int(seed))  # use a CPU generator, as seeding does not work on mps
+
+     embeddings = base64_to_embedding(embeddings)
+     embeddings = torch.tensor(embeddings).to(device)
+
+     images_list = pipe(
+         # inp.tile(n_samples, 1, 1, 1),
+         # [embeddings * n_samples],
+         embeddings,
+         guidance_scale=scale,
+         num_inference_steps=steps,
+         generator=generator,
+     )
+
+     images = []
+     for i, image in enumerate(images_list["images"]):
+         images.append(image)
+         # images.append(embedding_image)
+     return images
+
+ def on_image_load_update_embeddings(image_data):
+     # image to embeddings
+     if image_data is None:
+         embeddings = prompt_to_embedding('')
+         embeddings_b64 = embedding_to_base64(embeddings)
+         return gr.Text.update(embeddings_b64)
+     embeddings = image_to_embedding(image_data)
+     embeddings_b64 = embedding_to_base64(embeddings)
+     return gr.Text.update(embeddings_b64)
+
+ def on_prompt_change_update_embeddings(prompt):
+     # prompt to embeddings
+     if prompt is None or prompt == "":
+         embeddings = prompt_to_embedding('')
+         embeddings_b64 = embedding_to_base64(embeddings)
+         return gr.Text.update(embeddings_b64)
+     embeddings = prompt_to_embedding(prompt)
+     embeddings_b64 = embedding_to_base64(embeddings)
+     return gr.Text.update(embeddings_b64)
+
+ # def on_embeddings_changed_update_average_embeddings(last_embedding_base64):
+ # def on_embeddings_changed_update_average_embeddings(embedding_base64s):
+ def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_base64, idx):
+     # global embedding_base64s
+     final_embedding = None
+     num_embeddings = 0
+     embedding_base64s_state[idx] = embedding_base64
+     # for textbox in embedding_base64s:
+     #     embedding_base64 = textbox.value
+     for embedding_base64 in embedding_base64s_state:
+         if embedding_base64 is None or embedding_base64 == "":
+             continue
+         embedding = base64_to_embedding(embedding_base64)
+         if final_embedding is None:
+             final_embedding = embedding
+         else:
+             final_embedding = final_embedding + embedding
+         num_embeddings += 1
+     if final_embedding is None:
+         embeddings = prompt_to_embedding('')
+         embeddings_b64 = embedding_to_base64(embeddings)
+         return gr.Text.update(embeddings_b64)
+     final_embedding = final_embedding / num_embeddings
+     embeddings_b64 = embedding_to_base64(final_embedding)
+     return gr.Text.update(embeddings_b64)
+
+ def on_embeddings_changed_update_plot(embeddings_b64):
+     # plot new embeddings
+     if embeddings_b64 is None or embeddings_b64 == "":
+         return gr.LinePlot.update()
+
+     embeddings = base64_to_embedding(embeddings_b64)
+     data = pd.DataFrame({
+         'embedding': embeddings,
+         'index': [n for n in range(len(embeddings))]})
+     return gr.LinePlot.update(data,
+         x="index",
+         y="embedding",
+         # color="country",
+         title="Embeddings",
+         # stroke_dash="cluster",
+         # x_lim=[1950, 2010],
+         tooltip=['index', 'embedding'],
+         # stroke_dash_legend_title="Country Cluster",
+         # height=300,
+         width=embeddings.shape[0])
+
+
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
+ pipe = StableDiffusionPipeline.from_pretrained(
+     model_id,
+     custom_pipeline="pipeline.py",
+     torch_dtype=torch.float16,
+     # revision="fp16",
+     requires_safety_checker=False, safety_checker=None,
+     text_encoder=CLIPTextModel,
+     tokenizer=CLIPTokenizer,
+ )
+ pipe = pipe.to(device)
+
+ from transformers import AutoProcessor, AutoModel
+ processor = AutoProcessor.from_pretrained(clip_model_id)
+ model = AutoModel.from_pretrained(clip_model_id)
+ model = model.to(device)
+
+ examples = [
+     ["frog.png", 3, 1, 25, 0],
+     ["img0.jpg", 3, 1, 25, 0],
+     ["img1.jpg", 3, 1, 25, 0],
+     ["img2.jpg", 3, 1, 25, 0],
+     ["img3.jpg", 3, 1, 25, 0],
+ ]
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         for i in range(max_tabs):
+             with gr.Tab(f"Input {i}"):
+                 with gr.Row():
+                     with gr.Column(scale=1, min_width=240):
+                         input_images[i] = gr.Image()
+                     with gr.Column(scale=3, min_width=600):
+                         embedding_plots[i] = gr.LinePlot(show_label=False).style(container=False)
+                         # input_image.change(on_image_load, inputs=[input_image, plot])
+                 with gr.Row():
+                     with gr.Column(scale=1, min_width=240):
+                         input_prompts[i] = gr.Textbox()
+                     with gr.Column(scale=3, min_width=600):
+                         with gr.Accordion("Embeddings", open=False):
+                             embedding_base64s[i] = gr.Textbox(show_label=False)
+
+     with gr.Row():
+         average_embedding_plot = gr.LinePlot(show_label=False).style(container=False)
+     with gr.Row():
+         average_embedding_base64 = gr.Textbox(show_label=False)
+
+     with gr.Row():
+         with gr.Column(scale=1, min_width=200):
+             scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
+         with gr.Column(scale=1, min_width=200):
+             n_samples = gr.Slider(1, 4, value=1, step=1, label="Number images")
+         with gr.Column(scale=1, min_width=200):
+             steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
+         with gr.Column(scale=1, min_width=200):
+             seed = gr.Number(None, label="Seed", precision=0)
+     with gr.Row():
+         submit = gr.Button("Submit")
+     with gr.Row():
+         output = gr.Gallery(label="Generated variations")
+
+     embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
+     for i in range(max_tabs):
+         input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
+         input_prompts[i].submit(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
+         embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
+         # embedding_plots[i].change(on_plot_changed, embedding_base64s[i], average_embedding_base64)
+         # embedding_plots[i].change(on_embeddings_changed_update_average_embeddings, embedding_base64s[i], average_embedding_base64)
+         idx_state = gr.State(value=i)
+         embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_base64s[i], idx_state], average_embedding_base64)
+
+     average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)
+
+     # submit.click(main, inputs=[embedding_base64s[0], scale, n_samples, steps, seed], outputs=output)
+     submit.click(main, inputs=[average_embedding_base64, scale, n_samples, steps, seed], outputs=output)
+     output.style(grid=2)
+
+
+
+ if __name__ == "__main__":
+     demo.launch()
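
Note: app.py moves CLIP embeddings between Gradio components as base64 text (see embedding_to_base64 / base64_to_embedding above). A minimal standalone sketch of that round trip, assuming only numpy and the standard library (the names below are illustrative, not part of the commit):

    import base64
    import numpy as np

    def to_b64(vec):
        # cast to float16 so the string stays short; the decoder must use the same dtype
        return base64.urlsafe_b64encode(vec.astype(np.float16).tobytes()).decode()

    def from_b64(s):
        # reverse of to_b64: raw bytes back to a flat float16 vector
        return np.frombuffer(base64.urlsafe_b64decode(s), dtype=np.float16)

    vec = np.random.rand(768).astype(np.float32)  # stand-in for a CLIP embedding
    assert np.allclose(from_b64(to_b64(vec)), vec.astype(np.float16))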
pipeline.py ADDED
@@ -0,0 +1,486 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import Callable, List, Optional, Union
+
+ import torch
+
+ import PIL
+ from diffusers.utils import is_accelerate_available
+ from packaging import version
+ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ # from ...configuration_utils import FrozenDict
+ # from ...models import AutoencoderKL, UNet2DConditionModel
+ # from ...pipeline_utils import DiffusionPipeline
+ # from ...schedulers import (
+ #     DDIMScheduler,
+ #     DPMSolverMultistepScheduler,
+ #     EulerAncestralDiscreteScheduler,
+ #     EulerDiscreteScheduler,
+ #     LMSDiscreteScheduler,
+ #     PNDMScheduler,
+ # )
+ # from ...utils import deprecate, logging
+ # from . import StableDiffusionPipelineOutput
+ # from .safety_checker import StableDiffusionSafetyChecker
+ from diffusers.configuration_utils import FrozenDict
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipeline_utils import DiffusionPipeline
+ from diffusers.schedulers import (
+     DDIMScheduler,
+     DPMSolverMultistepScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     LMSDiscreteScheduler,
+     PNDMScheduler,
+ )
+ from diffusers.utils import deprecate, logging
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class StableDiffusionImageTextVariationPipeline(DiffusionPipeline):
+     r"""
+     Pipeline to generate variations from an input image using Stable Diffusion.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         image_encoder ([`CLIPVisionModelWithProjection`]):
+             Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+             specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+         feature_extractor ([`CLIPFeatureExtractor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+     _optional_components = ["safety_checker"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         image_encoder: CLIPVisionModelWithProjection,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: Union[
+             DDIMScheduler,
+             PNDMScheduler,
+             LMSDiscreteScheduler,
+             EulerDiscreteScheduler,
+             EulerAncestralDiscreteScheduler,
+             DPMSolverMultistepScheduler,
+         ],
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+         requires_safety_checker: bool = True,
+     ):
+         super().__init__()
+
+         if safety_checker is None and requires_safety_checker:
+             logger.warn(
+                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+             )
+
+         if safety_checker is not None and feature_extractor is None:
+             raise ValueError(
+                 f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+             )
+
+         is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+             version.parse(unet.config._diffusers_version).base_version
+         ) < version.parse("0.9.0.dev0")
+         is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+             deprecation_message = (
+                 "The configuration file of the unet has set the default `sample_size` to smaller than"
+                 " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                 " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                 " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                 " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                 " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                 " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                 " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                 " the `unet/config.json` file"
+             )
+             deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(unet.config)
+             new_config["sample_size"] = 64
+             unet._internal_dict = FrozenDict(new_config)
+
+         self.register_modules(
+             vae=vae,
+             image_encoder=image_encoder,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+     def enable_sequential_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+         """
+         if is_accelerate_available():
+             from accelerate import cpu_offload
+         else:
+             raise ImportError("Please install accelerate via `pip install accelerate`")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
+             if cpu_offloaded_model is not None:
+                 cpu_offload(cpu_offloaded_model, device)
+
+     @property
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+             return self.device
+         for module in self.unet.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+         dtype = next(self.image_encoder.parameters()).dtype
+
+         if not isinstance(image, torch.Tensor):
+             image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+         image = image.to(device=device, dtype=dtype)
+         image_embeddings = self.image_encoder(image).image_embeds
+         image_embeddings = image_embeddings.unsqueeze(1)
+
+         # duplicate image embeddings for each generation per prompt, using mps friendly method
+         bs_embed, seq_len, _ = image_embeddings.shape
+         image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+         image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             uncond_embeddings = torch.zeros_like(image_embeddings)
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+         return image_embeddings
+
+     def _prepare_embeddings(self, embeddings, device, num_images_per_prompt, do_classifier_free_guidance):
+         dtype = next(self.image_encoder.parameters()).dtype
+
+         # duplicate image embeddings for each generation per prompt, using mps friendly method
+         # bs_embed, seq_len, _ = embeddings.shape
+         # bs_embed = len(embeddings)
+         # seq_len = embeddings[0].shape[0]
+         # embeddings = embeddings.repeat(1, num_images_per_prompt, 1)
+         embeddings = embeddings.repeat(1, 1, num_images_per_prompt)
+         # embeddings = embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             uncond_embeddings = torch.zeros_like(embeddings)
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             embeddings = torch.cat([uncond_embeddings, embeddings])
+
+         return embeddings
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+     def run_safety_checker(self, image, device, dtype):
+         if self.safety_checker is not None:
+             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+             )
+         else:
+             has_nsfw_concept = None
+         return image, has_nsfw_concept
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+     def decode_latents(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+         return image
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def check_inputs(self, image, height, width, callback_steps):
+         if (
+             not isinstance(image, torch.Tensor)
+             and not isinstance(image, PIL.Image.Image)
+             and not isinstance(image, list)
+         ):
+             raise ValueError(
+                 "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+                 f" {type(image)}"
+             )
+
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             rand_device = "cpu" if device.type == "mps" else device
+
+             if isinstance(generator, list):
+                 shape = (1,) + shape[1:]
+                 latents = [
+                     torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+                     for i in range(batch_size)
+                 ]
+                 latents = torch.cat(latents, dim=0).to(device)
+             else:
+                 latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+         else:
+             if latents.shape != shape:
+                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+             latents = latents.to(device)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         # image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+         embeddings: torch.FloatTensor,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: Optional[int] = 1,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+                 The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
+                 configuration of
+                 [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+                 `CLIPFeatureExtractor`
+             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         # self.check_inputs(image, height, width, callback_steps)
+
+         # 2. Define call parameters
+         if isinstance(embeddings, list):
+             batch_size = len(embeddings)
+         else:
+             batch_size = 1
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input image
+         embeddings = self._prepare_embeddings(embeddings, device, num_images_per_prompt, do_classifier_free_guidance)
+         embeddings = embeddings.to(device)
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             embeddings.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 # predict the noise residual
+                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=embeddings).sample
+
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+         # 8. Post-processing
+         image = self.decode_latents(latents)
+
+         # 9. Run safety checker
+         image, has_nsfw_concept = self.run_safety_checker(image, device, embeddings.dtype)
+
+         # 10. Convert to PIL
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
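
Note: a minimal sketch of driving this custom pipeline outside Gradio, mirroring how app.py loads it (model id and call arguments are taken from app.py; the zero tensor is a hypothetical stand-in for a real CLIP embedding, and per the commit message generation output is still wrong, so this only exercises the call path):

    import torch
    from diffusers import StableDiffusionPipeline

    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained(
        "lambdalabs/sd-image-variations-diffusers",
        custom_pipeline="pipeline.py",   # this file, resolved from the working directory
        torch_dtype=torch.float16,
        safety_checker=None, requires_safety_checker=False,
    ).to(device)

    embedding = torch.zeros(768, dtype=torch.float16, device=device)  # stand-in embedding
    result = pipe(embedding, guidance_scale=3.0, num_inference_steps=25,
                  generator=torch.Generator().manual_seed(0))
    result["images"][0].save("variation.png")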
pup1.jpg ADDED
pup2.jpg ADDED
pup3.jpg ADDED
pup4.jpeg ADDED
pup5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ numpy
+ transformers
+ diffusers
+ # ftfy
+ gradio
+ accelerate
test-platform.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+
+ # set device to mps if available, cuda if available, cpu otherwise
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
+ print(device)
+
+ x = torch.zeros(1, device=device)
+ print(str(x))
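
Note: on an Apple Silicon machine this script should print mps and a zero tensor allocated on that device; on a CUDA machine it prints cuda:0, otherwise cpu. It is a quick check that the same device-selection expression used in app.py picks up the expected backend.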