DamarJati committed on
Commit
30bbb9e
1 Parent(s): a126ec1

Upload 9 files

src/flux/__init__.py ADDED
@@ -0,0 +1,11 @@
+ try:
+     from ._version import version as __version__  # type: ignore
+     from ._version import version_tuple
+ except ImportError:
+     __version__ = "unknown (no version information available)"
+     version_tuple = (0, 0, "unknown", "noinfo")
+
+ from pathlib import Path
+
+ PACKAGE = __package__.replace("_", "-")
+ PACKAGE_ROOT = Path(__file__).parent
src/flux/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .cli import app
+
+ if __name__ == "__main__":
+     app()
src/flux/api.py ADDED
@@ -0,0 +1,194 @@
+ import io
+ import os
+ import time
+ from pathlib import Path
+
+ import requests
+ from PIL import Image
+
+ API_ENDPOINT = "https://api.bfl.ml"
+
+
+ class ApiException(Exception):
+     def __init__(self, status_code: int, detail: str | list[dict] | None = None):
+         super().__init__()
+         self.detail = detail
+         self.status_code = status_code
+
+     def __str__(self) -> str:
+         return self.__repr__()
+
+     def __repr__(self) -> str:
+         if self.detail is None:
+             message = None
+         elif isinstance(self.detail, str):
+             message = self.detail
+         else:
+             message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
+         return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
+
+
+ class ImageRequest:
+     def __init__(
+         self,
+         prompt: str,
+         width: int = 1024,
+         height: int = 1024,
+         name: str = "flux.1-pro",
+         num_steps: int = 50,
+         prompt_upsampling: bool = False,
+         seed: int | None = None,
+         validate: bool = True,
+         launch: bool = True,
+         api_key: str | None = None,
+     ):
+         """
+         Manages an image generation request to the API.
+
+         Args:
+             prompt: Prompt to sample
+             width: Width of the image in pixels
+             height: Height of the image in pixels
+             name: Name of the model
+             num_steps: Number of network evaluations
+             prompt_upsampling: Use prompt upsampling
+             seed: Fix the generation seed
+             validate: Run input validation
+             launch: Directly launch the request
+             api_key: Your API key if not provided by the environment
+
+         Raises:
+             ValueError: For invalid input
+             ApiException: For errors raised from the API
+         """
+         if validate:
+             if name not in ["flux.1-pro"]:
+                 raise ValueError(f"Invalid model {name}")
+             elif width % 32 != 0:
+                 raise ValueError(f"width must be divisible by 32, got {width}")
+             elif not (256 <= width <= 1440):
+                 raise ValueError(f"width must be between 256 and 1440, got {width}")
+             elif height % 32 != 0:
+                 raise ValueError(f"height must be divisible by 32, got {height}")
+             elif not (256 <= height <= 1440):
+                 raise ValueError(f"height must be between 256 and 1440, got {height}")
+             elif not (1 <= num_steps <= 50):
+                 raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
+
+         self.request_json = {
+             "prompt": prompt,
+             "width": width,
+             "height": height,
+             "variant": name,
+             "steps": num_steps,
+             "prompt_upsampling": prompt_upsampling,
+         }
+         if seed is not None:
+             self.request_json["seed"] = seed
+
+         self.request_id: str | None = None
+         self.result: dict | None = None
+         self._image_bytes: bytes | None = None
+         self._url: str | None = None
+         if api_key is None:
+             self.api_key = os.environ.get("BFL_API_KEY")
+         else:
+             self.api_key = api_key
+
+         if launch:
+             self.request()
+
+     def request(self):
+         """
+         Request to generate the image.
+         """
+         if self.request_id is not None:
+             return
+         response = requests.post(
+             f"{API_ENDPOINT}/v1/image",
+             headers={
+                 "accept": "application/json",
+                 "x-key": self.api_key,
+                 "Content-Type": "application/json",
+             },
+             json=self.request_json,
+         )
+         result = response.json()
+         if response.status_code != 200:
+             raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+         self.request_id = result["id"]
+
+     def retrieve(self) -> dict:
+         """
+         Wait for the generation to finish and retrieve the response.
+         """
+         if self.request_id is None:
+             self.request()
+         while self.result is None:
+             response = requests.get(
+                 f"{API_ENDPOINT}/v1/get_result",
+                 headers={
+                     "accept": "application/json",
+                     "x-key": self.api_key,
+                 },
+                 params={
+                     "id": self.request_id,
+                 },
+             )
+             result = response.json()
+             if "status" not in result:
+                 raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+             elif result["status"] == "Ready":
+                 self.result = result["result"]
+             elif result["status"] == "Pending":
+                 time.sleep(0.5)
+             else:
+                 raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
+         return self.result
+
+     @property
+     def bytes(self) -> bytes:
+         """
+         Generated image as bytes.
+         """
+         if self._image_bytes is None:
+             response = requests.get(self.url)
+             if response.status_code == 200:
+                 self._image_bytes = response.content
+             else:
+                 raise ApiException(status_code=response.status_code)
+         return self._image_bytes
+
+     @property
+     def url(self) -> str:
+         """
+         Public url to retrieve the image from.
+         """
+         if self._url is None:
+             result = self.retrieve()
+             self._url = result["sample"]
+         return self._url
+
+     @property
+     def image(self) -> Image.Image:
+         """
+         Load the image as a PIL Image.
+         """
+         return Image.open(io.BytesIO(self.bytes))
+
+     def save(self, path: str):
+         """
+         Save the generated image to a local path.
+         """
+         suffix = Path(self.url).suffix
+         if not path.endswith(suffix):
+             path = path + suffix
+         Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
+         with open(path, "wb") as file:
+             file.write(self.bytes)
+
+
+ if __name__ == "__main__":
+     from fire import Fire
+
+     Fire(ImageRequest)
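
A minimal usage sketch for the client above. It needs a valid BFL_API_KEY in the environment and network access to api.bfl.ml; the prompt and output path are illustrative.

from flux.api import ImageRequest

# Launches the request immediately (launch=True by default), polls until the
# result is ready, then downloads and saves the image.
request = ImageRequest(prompt="a photo of a forest", width=1024, height=768)
print(request.url)          # public URL of the generated sample
request.save("forest.jpg")  # the URL's file suffix is appended if it is missing
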
src/flux/cli.py ADDED
@@ -0,0 +1,254 @@
+ import os
+ import re
+ import time
+ from dataclasses import dataclass
+ from glob import iglob
+
+ import torch
+ from einops import rearrange
+ from fire import Fire
+ from PIL import ExifTags, Image
+ from transformers import pipeline
+
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+ from flux.util import (configs, embed_watermark, load_ae, load_clip,
+                        load_flow_model, load_t5)
+
+ NSFW_THRESHOLD = 0.85
+
+
+ @dataclass
+ class SamplingOptions:
+     prompt: str
+     width: int
+     height: int
+     num_steps: int
+     guidance: float
+     seed: int | None
+
+
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
+     user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
+     usage = (
+         "Usage: Either write your prompt directly, leave this field empty "
+         "to repeat the prompt or write a command starting with a slash:\n"
+         "- '/w <width>' will set the width of the generated image\n"
+         "- '/h <height>' will set the height of the generated image\n"
+         "- '/s <seed>' sets the next seed\n"
+         "- '/g <guidance>' sets the guidance (flux-dev only)\n"
+         "- '/n <steps>' sets the number of steps\n"
+         "- '/q' to quit"
+     )
+
+     while (prompt := input(user_question)).startswith("/"):
+         if prompt.startswith("/w"):
+             if prompt.count(" ") != 1:
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+                 continue
+             _, width = prompt.split()
+             options.width = 16 * (int(width) // 16)
+             print(
+                 f"Setting resolution to {options.width} x {options.height} "
+                 f"({options.height * options.width / 1e6:.2f}MP)"
+             )
+         elif prompt.startswith("/h"):
+             if prompt.count(" ") != 1:
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+                 continue
+             _, height = prompt.split()
+             options.height = 16 * (int(height) // 16)
+             print(
+                 f"Setting resolution to {options.width} x {options.height} "
+                 f"({options.height * options.width / 1e6:.2f}MP)"
+             )
+         elif prompt.startswith("/g"):
+             if prompt.count(" ") != 1:
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+                 continue
+             _, guidance = prompt.split()
+             options.guidance = float(guidance)
+             print(f"Setting guidance to {options.guidance}")
+         elif prompt.startswith("/s"):
+             if prompt.count(" ") != 1:
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+                 continue
+             _, seed = prompt.split()
+             options.seed = int(seed)
+             print(f"Setting seed to {options.seed}")
+         elif prompt.startswith("/n"):
+             if prompt.count(" ") != 1:
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+                 continue
+             _, steps = prompt.split()
+             options.num_steps = int(steps)
+             print(f"Setting number of steps to {options.num_steps}")
+         elif prompt.startswith("/q"):
+             print("Quitting")
+             return None
+         else:
+             if not prompt.startswith("/h"):
+                 print(f"Got invalid command '{prompt}'\n{usage}")
+             print(usage)
+     if prompt != "":
+         options.prompt = prompt
+     return options
+
+
+ @torch.inference_mode()
+ def main(
+     name: str = "flux-schnell",
+     width: int = 1360,
+     height: int = 768,
+     seed: int | None = None,
+     prompt: str = (
+         "a photo of a forest with mist swirling around the tree trunks. The word "
+         '"FLUX" is painted over it in big, red brush strokes with visible texture'
+     ),
+     device: str = "cuda" if torch.cuda.is_available() else "cpu",
+     num_steps: int | None = None,
+     loop: bool = False,
+     guidance: float = 3.5,
+     offload: bool = False,
+     output_dir: str = "output",
+     add_sampling_metadata: bool = True,
+ ):
+     """
+     Sample the flux model. Either interactively (set `--loop`) or run for a
+     single image.
+
+     Args:
+         name: Name of the model to load
+         height: height of the sample in pixels (should be a multiple of 16)
+         width: width of the sample in pixels (should be a multiple of 16)
+         seed: Set a seed for sampling
+         prompt: Prompt used for sampling
+         device: PyTorch device
+         num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
+         loop: start an interactive session and sample multiple times
+         guidance: guidance value used for guidance distillation
+         offload: move models to CPU when they are not in use
+         output_dir: where to save the output images; files are named `img_{idx}.jpg`,
+             where `{idx}` is the index of the sample
+         add_sampling_metadata: Add the prompt to the image Exif metadata
+     """
+     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
+
+     if name not in configs:
+         available = ", ".join(configs.keys())
+         raise ValueError(f"Got unknown model name: {name}, choose from {available}")
+
+     torch_device = torch.device(device)
+     if num_steps is None:
+         num_steps = 4 if name == "flux-schnell" else 50
+
+     # allow for packing and conversion to latent space
+     height = 16 * (height // 16)
+     width = 16 * (width // 16)
+
+     output_name = os.path.join(output_dir, "img_{idx}.jpg")
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+         idx = 0
+     else:
+         fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+         if len(fns) > 0:
+             idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
+         else:
+             idx = 0
+
+     # init all components
+     t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
+     clip = load_clip(torch_device)
+     model = load_flow_model(name, device="cpu" if offload else torch_device)
+     ae = load_ae(name, device="cpu" if offload else torch_device)
+
+     rng = torch.Generator(device="cpu")
+     opts = SamplingOptions(
+         prompt=prompt,
+         width=width,
+         height=height,
+         num_steps=num_steps,
+         guidance=guidance,
+         seed=seed,
+     )
+
+     if loop:
+         opts = parse_prompt(opts)
+
+     while opts is not None:
+         if opts.seed is None:
+             opts.seed = rng.seed()
+         print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
+         t0 = time.perf_counter()
+
+         # prepare input
+         x = get_noise(
+             1,
+             opts.height,
+             opts.width,
+             device=torch_device,
+             dtype=torch.bfloat16,
+             seed=opts.seed,
+         )
+         opts.seed = None
+         if offload:
+             ae = ae.cpu()
+             torch.cuda.empty_cache()
+             t5, clip = t5.to(torch_device), clip.to(torch_device)
+         inp = prepare(t5, clip, x, prompt=opts.prompt)
+         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
+
+         # offload TEs to CPU, load model to gpu
+         if offload:
+             t5, clip = t5.cpu(), clip.cpu()
+             torch.cuda.empty_cache()
+             model = model.to(torch_device)
+
+         # denoise initial noise
+         x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
+
+         # offload model, load autoencoder to gpu
+         if offload:
+             model.cpu()
+             torch.cuda.empty_cache()
+             ae.decoder.to(x.device)
+
+         # decode latents to pixel space
+         x = unpack(x.float(), opts.height, opts.width)
+         with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+             x = ae.decode(x)
+         t1 = time.perf_counter()
+
+         fn = output_name.format(idx=idx)
+         print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
+         # bring into PIL format and save
+         x = x.clamp(-1, 1)
+         x = embed_watermark(x.float())
+         x = rearrange(x[0], "c h w -> h w c")
+
+         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+         nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
+
+         if nsfw_score < NSFW_THRESHOLD:
+             exif_data = Image.Exif()
+             exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+             exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+             exif_data[ExifTags.Base.Model] = name
+             if add_sampling_metadata:
+                 exif_data[ExifTags.Base.ImageDescription] = prompt
+             img.save(fn, exif=exif_data, quality=95, subsampling=0)
+             idx += 1
+         else:
+             print("Your generated image may contain NSFW content.")
+
+         if loop:
+             print("-" * 80)
+             opts = parse_prompt(opts)
+         else:
+             opts = None
+
+
+ def app():
+     Fire(main)
+
+
+ if __name__ == "__main__":
+     app()
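
A sketch of driving the CLI above from Python instead of the command line. The prompt and output directory are illustrative; unless FLUX_SCHNELL points at a local checkpoint, the weights are downloaded from Hugging Face on first use.

from flux.cli import main

# Equivalent to `python -m flux` with these flags; Fire exposes the same
# parameters as command-line options (--name, --width, --loop, ...).
main(
    name="flux-schnell",
    width=1024,
    height=768,
    prompt="a watercolor painting of a lighthouse at dusk",
    loop=False,
    output_dir="output",
)
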
src/flux/controlnet.py ADDED
@@ -0,0 +1,222 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict
+
+ import torch
+ from torch import Tensor, nn
+ from einops import rearrange
+
+ # NOTE: `is_torch_version` (used in the gradient-checkpointing branch below) is
+ # assumed to come from diffusers.
+ from diffusers.utils import is_torch_version
+
+ from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                              MLPEmbedder, SingleStreamBlock,
+                              timestep_embedding)
+
+
+ @dataclass
+ class FluxParams:
+     in_channels: int
+     vec_in_dim: int
+     context_in_dim: int
+     hidden_size: int
+     mlp_ratio: float
+     num_heads: int
+     depth: int
+     depth_single_blocks: int
+     axes_dim: list[int]
+     theta: int
+     qkv_bias: bool
+     guidance_embed: bool
+
+
+ def zero_module(module):
+     # zero-initialize all parameters so the ControlNet branch starts as a no-op
+     for p in module.parameters():
+         nn.init.zeros_(p)
+     return module
+
+
+ class ControlNetFlux(nn.Module):
+     """
+     Transformer model for flow matching on sequences.
+     """
+
+     _supports_gradient_checkpointing = True
+
+     def __init__(self, params: FluxParams, controlnet_depth=2):
+         super().__init__()
+
+         self.params = params
+         self.in_channels = params.in_channels
+         self.out_channels = self.in_channels
+         if params.hidden_size % params.num_heads != 0:
+             raise ValueError(
+                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+             )
+         pe_dim = params.hidden_size // params.num_heads
+         if sum(params.axes_dim) != pe_dim:
+             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+         self.hidden_size = params.hidden_size
+         self.num_heads = params.num_heads
+         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+         self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+         self.guidance_in = (
+             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+         )
+         self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+         self.double_blocks = nn.ModuleList(
+             [
+                 DoubleStreamBlock(
+                     self.hidden_size,
+                     self.num_heads,
+                     mlp_ratio=params.mlp_ratio,
+                     qkv_bias=params.qkv_bias,
+                 )
+                 for _ in range(controlnet_depth)
+             ]
+         )
+
+         # add ControlNet blocks
+         self.controlnet_blocks = nn.ModuleList([])
+         for _ in range(controlnet_depth):
+             controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
+             controlnet_block = zero_module(controlnet_block)
+             self.controlnet_blocks.append(controlnet_block)
+         self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+         self.gradient_checkpointing = False
+         self.input_hint_block = nn.Sequential(
+             nn.Conv2d(3, 16, 3, padding=1),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1, stride=2),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1, stride=2),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1),
+             nn.SiLU(),
+             nn.Conv2d(16, 16, 3, padding=1, stride=2),
+             nn.SiLU(),
+             zero_module(nn.Conv2d(16, 16, 3, padding=1)),
+         )
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     @property
+     def attn_processors(self):
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
+             if hasattr(module, "set_processor"):
+                 processors[f"{name}.processor"] = module.processor
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     def set_attn_processor(self, processor):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as
+                 the processor for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def forward(
+         self,
+         img: Tensor,
+         img_ids: Tensor,
+         controlnet_cond: Tensor,
+         txt: Tensor,
+         txt_ids: Tensor,
+         timesteps: Tensor,
+         y: Tensor,
+         guidance: Tensor | None = None,
+     ) -> Tensor:
+         if img.ndim != 3 or txt.ndim != 3:
+             raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+         # running on sequences img
+         img = self.img_in(img)
+         controlnet_cond = self.input_hint_block(controlnet_cond)
+         controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+         controlnet_cond = self.pos_embed_input(controlnet_cond)
+         img = img + controlnet_cond
+         vec = self.time_in(timestep_embedding(timesteps, 256))
+         if self.params.guidance_embed:
+             if guidance is None:
+                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
+             vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+         vec = vec + self.vector_in(y)
+         txt = self.txt_in(txt)
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         pe = self.pe_embedder(ids)
+
+         block_res_samples = ()
+
+         for block in self.double_blocks:
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 # DoubleStreamBlock returns the updated (img, txt) pair
+                 img, txt = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     img,
+                     txt,
+                     vec,
+                     pe,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+             block_res_samples = block_res_samples + (img,)
+
+         controlnet_block_res_samples = ()
+         for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+             block_res_sample = controlnet_block(block_res_sample)
+             controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+         return controlnet_block_res_samples
src/flux/math.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from einops import rearrange
+ from torch import Tensor
+
+
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+     q, k = apply_rope(q, k, pe)
+
+     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+     x = rearrange(x, "B H L D -> B L (H D)")
+
+     return x
+
+
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+     assert dim % 2 == 0
+     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+     omega = 1.0 / (theta**scale)
+     out = torch.einsum("...n,d->...nd", pos, omega)
+     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+     return out.float()
+
+
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
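
A small shape check for the RoPE helpers above (a sketch; the sizes are arbitrary and only need to satisfy `dim % 2 == 0` and match the head dimension):

import torch
from flux.math import attention, rope

pos = torch.arange(8)[None].float()   # (batch=1, seq=8) token positions
pe = rope(pos, dim=16, theta=10_000)  # (1, 8, 8, 2, 2): one 2x2 rotation per frequency pair
pe = pe[:, None]                      # extra axis so the rotations broadcast over attention heads

q = torch.randn(1, 4, 8, 16)          # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
out = attention(q, k, v, pe)          # (1, 8, 64) = (batch, seq, heads * head_dim)
print(pe.shape, out.shape)
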
src/flux/model.py ADDED
@@ -0,0 +1,217 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict
+
+ import torch
+ from torch import Tensor, nn
+ from einops import rearrange
+
+ # NOTE: `is_torch_version` (used in the gradient-checkpointing branches below) is
+ # assumed to come from diffusers.
+ from diffusers.utils import is_torch_version
+
+ from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                              MLPEmbedder, SingleStreamBlock,
+                              timestep_embedding)
+
+
+ @dataclass
+ class FluxParams:
+     in_channels: int
+     vec_in_dim: int
+     context_in_dim: int
+     hidden_size: int
+     mlp_ratio: float
+     num_heads: int
+     depth: int
+     depth_single_blocks: int
+     axes_dim: list[int]
+     theta: int
+     qkv_bias: bool
+     guidance_embed: bool
+
+
+ class Flux(nn.Module):
+     """
+     Transformer model for flow matching on sequences.
+     """
+
+     _supports_gradient_checkpointing = True
+
+     def __init__(self, params: FluxParams):
+         super().__init__()
+
+         self.params = params
+         self.in_channels = params.in_channels
+         self.out_channels = self.in_channels
+         if params.hidden_size % params.num_heads != 0:
+             raise ValueError(
+                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+             )
+         pe_dim = params.hidden_size // params.num_heads
+         if sum(params.axes_dim) != pe_dim:
+             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+         self.hidden_size = params.hidden_size
+         self.num_heads = params.num_heads
+         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+         self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+         self.guidance_in = (
+             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+         )
+         self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+         self.double_blocks = nn.ModuleList(
+             [
+                 DoubleStreamBlock(
+                     self.hidden_size,
+                     self.num_heads,
+                     mlp_ratio=params.mlp_ratio,
+                     qkv_bias=params.qkv_bias,
+                 )
+                 for _ in range(params.depth)
+             ]
+         )
+
+         self.single_blocks = nn.ModuleList(
+             [
+                 SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+                 for _ in range(params.depth_single_blocks)
+             ]
+         )
+
+         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+         self.gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     @property
+     def attn_processors(self):
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
+             if hasattr(module, "set_processor"):
+                 processors[f"{name}.processor"] = module.processor
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     def set_attn_processor(self, processor):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as
+                 the processor for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def forward(
+         self,
+         img: Tensor,
+         img_ids: Tensor,
+         txt: Tensor,
+         txt_ids: Tensor,
+         timesteps: Tensor,
+         y: Tensor,
+         block_controlnet_hidden_states=None,
+         guidance: Tensor | None = None,
+     ) -> Tensor:
+         if img.ndim != 3 or txt.ndim != 3:
+             raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+         # running on sequences img
+         img = self.img_in(img)
+         vec = self.time_in(timestep_embedding(timesteps, 256))
+         if self.params.guidance_embed:
+             if guidance is None:
+                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
+             vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+         vec = vec + self.vector_in(y)
+         txt = self.txt_in(txt)
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         pe = self.pe_embedder(ids)
+         if block_controlnet_hidden_states is not None:
+             controlnet_depth = len(block_controlnet_hidden_states)
+         for index_block, block in enumerate(self.double_blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 # DoubleStreamBlock returns the updated (img, txt) pair
+                 img, txt = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     img,
+                     txt,
+                     vec,
+                     pe,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+             # controlnet residual
+             if block_controlnet_hidden_states is not None:
+                 img = img + block_controlnet_hidden_states[index_block % 2]
+
+         img = torch.cat((txt, img), 1)
+         for block in self.single_blocks:
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 # SingleStreamBlock returns the updated joint sequence
+                 img = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     img,
+                     vec,
+                     pe,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 img = block(img, vec=vec, pe=pe)
+         img = img[:, txt.shape[1] :, ...]
+
+         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+         return img
src/flux/sampling.py ADDED
@@ -0,0 +1,188 @@
+ import math
+ from typing import Callable
+
+ import torch
+ from einops import rearrange, repeat
+ from torch import Tensor
+
+ from .model import Flux
+ from .modules.conditioner import HFEmbedder
+
+
+ def get_noise(
+     num_samples: int,
+     height: int,
+     width: int,
+     device: torch.device,
+     dtype: torch.dtype,
+     seed: int,
+ ):
+     return torch.randn(
+         num_samples,
+         16,
+         # allow for packing
+         2 * math.ceil(height / 16),
+         2 * math.ceil(width / 16),
+         device=device,
+         dtype=dtype,
+         generator=torch.Generator(device=device).manual_seed(seed),
+     )
+
+
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
+     bs, c, h, w = img.shape
+     if bs == 1 and not isinstance(prompt, str):
+         bs = len(prompt)
+
+     img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+     if img.shape[0] == 1 and bs > 1:
+         img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+     img_ids = torch.zeros(h // 2, w // 2, 3)
+     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+     if isinstance(prompt, str):
+         prompt = [prompt]
+     txt = t5(prompt)
+     if txt.shape[0] == 1 and bs > 1:
+         txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+     txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+     vec = clip(prompt)
+     if vec.shape[0] == 1 and bs > 1:
+         vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+     return {
+         "img": img,
+         "img_ids": img_ids.to(img.device),
+         "txt": txt.to(img.device),
+         "txt_ids": txt_ids.to(img.device),
+         "vec": vec.to(img.device),
+     }
+
+
+ def time_shift(mu: float, sigma: float, t: Tensor):
+     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+ def get_lin_function(
+     x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+ ) -> Callable[[float], float]:
+     m = (y2 - y1) / (x2 - x1)
+     b = y1 - m * x1
+     return lambda x: m * x + b
+
+
+ def get_schedule(
+     num_steps: int,
+     image_seq_len: int,
+     base_shift: float = 0.5,
+     max_shift: float = 1.15,
+     shift: bool = True,
+ ) -> list[float]:
+     # extra step for zero
+     timesteps = torch.linspace(1, 0, num_steps + 1)
+
+     # shifting the schedule to favor high timesteps for higher signal images
+     if shift:
+         # estimate mu based on linear interpolation between two points
+         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+         timesteps = time_shift(mu, 1.0, timesteps)
+
+     return timesteps.tolist()
+
+
+ def denoise(
+     model: Flux,
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     # sampling parameters
+     timesteps: list[float],
+     guidance: float = 4.0,
+     use_gs=False,
+     gs=4,
+ ):
+     # this is ignored for schnell
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+     for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         pred = model(
+             img=img,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec,
+             guidance=guidance_vec,
+         )
+         if use_gs:
+             # classifier-free guidance: the batch is expected to hold the
+             # unconditional and conditional predictions
+             pred_uncond, pred_text = pred.chunk(2)
+             pred = pred_uncond + gs * (pred_text - pred_uncond)
+
+         img = img + (t_prev - t_curr) * pred
+         # if use_gs:
+         #     img = torch.cat([img] * 2)
+
+     return img
+
+
+ def denoise_controlnet(
+     model: Flux,
+     controlnet,  # ControlNetFlux instance
+     # model input
+     img: Tensor,
+     img_ids: Tensor,
+     txt: Tensor,
+     txt_ids: Tensor,
+     vec: Tensor,
+     controlnet_cond,
+     # sampling parameters
+     timesteps: list[float],
+     guidance: float = 4.0,
+     controlnet_gs=0.7,
+ ):
+     # this is ignored for schnell
+     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+     for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+         block_res_samples = controlnet(
+             img=img,
+             img_ids=img_ids,
+             controlnet_cond=controlnet_cond,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec,
+             guidance=guidance_vec,
+         )
+         pred = model(
+             img=img,
+             img_ids=img_ids,
+             txt=txt,
+             txt_ids=txt_ids,
+             y=vec,
+             timesteps=t_vec,
+             guidance=guidance_vec,
+             block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples],
+         )
+
+         img = img + (t_prev - t_curr) * pred
+         # if use_gs:
+         #     img = torch.cat([img] * 2)
+
+     return img
+
+
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
+     return rearrange(
+         x,
+         "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+         h=math.ceil(height / 16),
+         w=math.ceil(width / 16),
+         ph=2,
+         pw=2,
+     )
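
A quick look at the schedule helper above. It runs without any model weights; the resolution is illustrative.

from flux.sampling import get_schedule

# image_seq_len for a 1024x1024 image: (1024 / 16) * (1024 / 16) = 4096 packed tokens
plain = get_schedule(4, 4096, shift=False)   # evenly spaced from 1 to 0
shifted = get_schedule(4, 4096, shift=True)  # shifted toward high timesteps, as used for flux-dev
print(plain)    # [1.0, 0.75, 0.5, 0.25, 0.0]
print(shifted)  # same endpoints, but the intermediate steps are pushed closer to 1
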
src/flux/util.py ADDED
@@ -0,0 +1,237 @@
+ import os
+ from dataclasses import dataclass
+
+ import torch
+ from einops import rearrange
+ from huggingface_hub import hf_hub_download
+ from imwatermark import WatermarkEncoder
+ from safetensors import safe_open
+ from safetensors.torch import load_file as load_sft
+
+ from .model import Flux, FluxParams
+ from .controlnet import ControlNetFlux
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
+ from .modules.conditioner import HFEmbedder
+
+
+ def load_safetensors(path):
+     tensors = {}
+     with safe_open(path, framework="pt", device="cpu") as f:
+         for key in f.keys():
+             tensors[key] = f.get_tensor(key)
+     return tensors
+
+
+ @dataclass
+ class ModelSpec:
+     params: FluxParams
+     ae_params: AutoEncoderParams
+     ckpt_path: str | None
+     ae_path: str | None
+     repo_id: str | None
+     repo_flow: str | None
+     repo_ae: str | None
+
+
+ configs = {
+     "flux-dev": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-dev",
+         repo_flow="flux1-dev.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_DEV"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=True,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+     "flux-schnell": ModelSpec(
+         repo_id="black-forest-labs/FLUX.1-schnell",
+         repo_flow="flux1-schnell.safetensors",
+         repo_ae="ae.safetensors",
+         ckpt_path=os.getenv("FLUX_SCHNELL"),
+         params=FluxParams(
+             in_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+             guidance_embed=False,
+         ),
+         ae_path=os.getenv("AE"),
+         ae_params=AutoEncoderParams(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         ),
+     ),
+ }
+
+
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+     if len(missing) > 0 and len(unexpected) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+         print("\n" + "-" * 79 + "\n")
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+     elif len(missing) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+     elif len(unexpected) > 0:
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+
+
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+     # Loading Flux
+     print("Init model")
+     ckpt_path = configs[name].ckpt_path
+     if (
+         ckpt_path is None
+         and configs[name].repo_id is not None
+         and configs[name].repo_flow is not None
+         and hf_download
+     ):
+         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+     with torch.device("meta" if ckpt_path is not None else device):
+         model = Flux(configs[name].params).to(torch.bfloat16)
+
+     if ckpt_path is not None:
+         print("Loading checkpoint")
+         # load_sft doesn't support torch.device
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+         print_load_warning(missing, unexpected)
+     return model
+
+
+ def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
+     # Loading Flux
+     print("Init model")
+     ckpt_path = configs[name].ckpt_path
+     if (
+         ckpt_path is None
+         and configs[name].repo_id is not None
+         and configs[name].repo_flow is not None
+         and hf_download
+     ):
+         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
+
+     with torch.device("meta" if ckpt_path is not None else device):
+         model = Flux(configs[name].params)
+
+     if ckpt_path is not None:
+         print("Loading checkpoint")
+         # load_sft doesn't support torch.device
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+         print_load_warning(missing, unexpected)
+     return model
+
+
+ def load_controlnet(name, device, transformer=None):
+     with torch.device(device):
+         controlnet = ControlNetFlux(configs[name].params)
+     if transformer is not None:
+         controlnet.load_state_dict(transformer.state_dict(), strict=False)
+     return controlnet
+
+
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+     return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+
+
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+     return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
+     ckpt_path = configs[name].ae_path
+     if (
+         ckpt_path is None
+         and configs[name].repo_id is not None
+         and configs[name].repo_ae is not None
+         and hf_download
+     ):
+         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
+
+     # Loading the autoencoder
+     print("Init AE")
+     with torch.device("meta" if ckpt_path is not None else device):
+         ae = AutoEncoder(configs[name].ae_params)
+
+     if ckpt_path is not None:
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+         print_load_warning(missing, unexpected)
+     return ae
+
+
+ class WatermarkEmbedder:
+     def __init__(self, watermark):
+         self.watermark = watermark
+         self.num_bits = len(WATERMARK_BITS)
+         self.encoder = WatermarkEncoder()
+         self.encoder.set_watermark("bits", self.watermark)
+
+     def __call__(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Adds a predefined watermark to the input image
+
+         Args:
+             image: ([N,] B, RGB, H, W) in range [-1, 1]
+
+         Returns:
+             same as input but watermarked
+         """
+         image = 0.5 * image + 0.5
+         squeeze = len(image.shape) == 4
+         if squeeze:
+             image = image[None, ...]
+         n = image.shape[0]
+         image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
+         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+         # watermarking library expects input as cv2 BGR format
+         for k in range(image_np.shape[0]):
+             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+         image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
+             image.device
+         )
+         image = torch.clamp(image / 255, min=0.0, max=1.0)
+         if squeeze:
+             image = image[0]
+         image = 2 * image - 1
+         return image
+
+
+ # A fixed 48-bit message that was chosen at random
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+ # module-level embedder used by the CLI (`from flux.util import embed_watermark`)
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
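
A sketch of pointing the loaders above at local checkpoints instead of downloading from Hugging Face. The paths are illustrative, and the environment variables must be set before `flux.util` is imported, because `configs` reads them at import time.

import os

# Illustrative local paths; `configs` reads FLUX_SCHNELL / FLUX_DEV / AE at import time.
os.environ["FLUX_SCHNELL"] = "/path/to/flux1-schnell.safetensors"
os.environ["AE"] = "/path/to/ae.safetensors"

from flux.util import load_ae, load_flow_model

model = load_flow_model("flux-schnell", device="cpu")  # no hf_hub_download needed
ae = load_ae("flux-schnell", device="cpu")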