Spaces:
Running
on
Zero
Running
on
Zero
Upload 9 files
Browse files- src/flux/__init__.py +11 -0
- src/flux/__main__.py +4 -0
- src/flux/api.py +194 -0
- src/flux/cli.py +254 -0
- src/flux/controlnet.py +222 -0
- src/flux/math.py +30 -0
- src/flux/model.py +217 -0
- src/flux/sampling.py +188 -0
- src/flux/util.py +237 -0
src/flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from ._version import version as __version__ # type: ignore
|
3 |
+
from ._version import version_tuple
|
4 |
+
except ImportError:
|
5 |
+
__version__ = "unknown (no version information available)"
|
6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
PACKAGE = __package__.replace("_", "-")
|
11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
src/flux/__main__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cli import app
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
app()
|
src/flux/api.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
API_ENDPOINT = "https://api.bfl.ml"
|
10 |
+
|
11 |
+
|
12 |
+
class ApiException(Exception):
|
13 |
+
def __init__(self, status_code: int, detail: str | list[dict] | None = None):
|
14 |
+
super().__init__()
|
15 |
+
self.detail = detail
|
16 |
+
self.status_code = status_code
|
17 |
+
|
18 |
+
def __str__(self) -> str:
|
19 |
+
return self.__repr__()
|
20 |
+
|
21 |
+
def __repr__(self) -> str:
|
22 |
+
if self.detail is None:
|
23 |
+
message = None
|
24 |
+
elif isinstance(self.detail, str):
|
25 |
+
message = self.detail
|
26 |
+
else:
|
27 |
+
message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
|
28 |
+
return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
|
29 |
+
|
30 |
+
|
31 |
+
class ImageRequest:
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
prompt: str,
|
35 |
+
width: int = 1024,
|
36 |
+
height: int = 1024,
|
37 |
+
name: str = "flux.1-pro",
|
38 |
+
num_steps: int = 50,
|
39 |
+
prompt_upsampling: bool = False,
|
40 |
+
seed: int | None = None,
|
41 |
+
validate: bool = True,
|
42 |
+
launch: bool = True,
|
43 |
+
api_key: str | None = None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Manages an image generation request to the API.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
prompt: Prompt to sample
|
50 |
+
width: Width of the image in pixel
|
51 |
+
height: Height of the image in pixel
|
52 |
+
name: Name of the model
|
53 |
+
num_steps: Number of network evaluations
|
54 |
+
prompt_upsampling: Use prompt upsampling
|
55 |
+
seed: Fix the generation seed
|
56 |
+
validate: Run input validation
|
57 |
+
launch: Directly launches request
|
58 |
+
api_key: Your API key if not provided by the environment
|
59 |
+
|
60 |
+
Raises:
|
61 |
+
ValueError: For invalid input
|
62 |
+
ApiException: For errors raised from the API
|
63 |
+
"""
|
64 |
+
if validate:
|
65 |
+
if name not in ["flux.1-pro"]:
|
66 |
+
raise ValueError(f"Invalid model {name}")
|
67 |
+
elif width % 32 != 0:
|
68 |
+
raise ValueError(f"width must be divisible by 32, got {width}")
|
69 |
+
elif not (256 <= width <= 1440):
|
70 |
+
raise ValueError(f"width must be between 256 and 1440, got {width}")
|
71 |
+
elif height % 32 != 0:
|
72 |
+
raise ValueError(f"height must be divisible by 32, got {height}")
|
73 |
+
elif not (256 <= height <= 1440):
|
74 |
+
raise ValueError(f"height must be between 256 and 1440, got {height}")
|
75 |
+
elif not (1 <= num_steps <= 50):
|
76 |
+
raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
|
77 |
+
|
78 |
+
self.request_json = {
|
79 |
+
"prompt": prompt,
|
80 |
+
"width": width,
|
81 |
+
"height": height,
|
82 |
+
"variant": name,
|
83 |
+
"steps": num_steps,
|
84 |
+
"prompt_upsampling": prompt_upsampling,
|
85 |
+
}
|
86 |
+
if seed is not None:
|
87 |
+
self.request_json["seed"] = seed
|
88 |
+
|
89 |
+
self.request_id: str | None = None
|
90 |
+
self.result: dict | None = None
|
91 |
+
self._image_bytes: bytes | None = None
|
92 |
+
self._url: str | None = None
|
93 |
+
if api_key is None:
|
94 |
+
self.api_key = os.environ.get("BFL_API_KEY")
|
95 |
+
else:
|
96 |
+
self.api_key = api_key
|
97 |
+
|
98 |
+
if launch:
|
99 |
+
self.request()
|
100 |
+
|
101 |
+
def request(self):
|
102 |
+
"""
|
103 |
+
Request to generate the image.
|
104 |
+
"""
|
105 |
+
if self.request_id is not None:
|
106 |
+
return
|
107 |
+
response = requests.post(
|
108 |
+
f"{API_ENDPOINT}/v1/image",
|
109 |
+
headers={
|
110 |
+
"accept": "application/json",
|
111 |
+
"x-key": self.api_key,
|
112 |
+
"Content-Type": "application/json",
|
113 |
+
},
|
114 |
+
json=self.request_json,
|
115 |
+
)
|
116 |
+
result = response.json()
|
117 |
+
if response.status_code != 200:
|
118 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
119 |
+
self.request_id = response.json()["id"]
|
120 |
+
|
121 |
+
def retrieve(self) -> dict:
|
122 |
+
"""
|
123 |
+
Wait for the generation to finish and retrieve response.
|
124 |
+
"""
|
125 |
+
if self.request_id is None:
|
126 |
+
self.request()
|
127 |
+
while self.result is None:
|
128 |
+
response = requests.get(
|
129 |
+
f"{API_ENDPOINT}/v1/get_result",
|
130 |
+
headers={
|
131 |
+
"accept": "application/json",
|
132 |
+
"x-key": self.api_key,
|
133 |
+
},
|
134 |
+
params={
|
135 |
+
"id": self.request_id,
|
136 |
+
},
|
137 |
+
)
|
138 |
+
result = response.json()
|
139 |
+
if "status" not in result:
|
140 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
141 |
+
elif result["status"] == "Ready":
|
142 |
+
self.result = result["result"]
|
143 |
+
elif result["status"] == "Pending":
|
144 |
+
time.sleep(0.5)
|
145 |
+
else:
|
146 |
+
raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
|
147 |
+
return self.result
|
148 |
+
|
149 |
+
@property
|
150 |
+
def bytes(self) -> bytes:
|
151 |
+
"""
|
152 |
+
Generated image as bytes.
|
153 |
+
"""
|
154 |
+
if self._image_bytes is None:
|
155 |
+
response = requests.get(self.url)
|
156 |
+
if response.status_code == 200:
|
157 |
+
self._image_bytes = response.content
|
158 |
+
else:
|
159 |
+
raise ApiException(status_code=response.status_code)
|
160 |
+
return self._image_bytes
|
161 |
+
|
162 |
+
@property
|
163 |
+
def url(self) -> str:
|
164 |
+
"""
|
165 |
+
Public url to retrieve the image from
|
166 |
+
"""
|
167 |
+
if self._url is None:
|
168 |
+
result = self.retrieve()
|
169 |
+
self._url = result["sample"]
|
170 |
+
return self._url
|
171 |
+
|
172 |
+
@property
|
173 |
+
def image(self) -> Image.Image:
|
174 |
+
"""
|
175 |
+
Load the image as a PIL Image
|
176 |
+
"""
|
177 |
+
return Image.open(io.BytesIO(self.bytes))
|
178 |
+
|
179 |
+
def save(self, path: str):
|
180 |
+
"""
|
181 |
+
Save the generated image to a local path
|
182 |
+
"""
|
183 |
+
suffix = Path(self.url).suffix
|
184 |
+
if not path.endswith(suffix):
|
185 |
+
path = path + suffix
|
186 |
+
Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
|
187 |
+
with open(path, "wb") as file:
|
188 |
+
file.write(self.bytes)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
from fire import Fire
|
193 |
+
|
194 |
+
Fire(ImageRequest)
|
src/flux/cli.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from glob import iglob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from fire import Fire
|
10 |
+
from PIL import ExifTags, Image
|
11 |
+
|
12 |
+
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
13 |
+
from flux.util import (configs, embed_watermark, load_ae, load_clip,
|
14 |
+
load_flow_model, load_t5)
|
15 |
+
from transformers import pipeline
|
16 |
+
|
17 |
+
NSFW_THRESHOLD = 0.85
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class SamplingOptions:
|
21 |
+
prompt: str
|
22 |
+
width: int
|
23 |
+
height: int
|
24 |
+
num_steps: int
|
25 |
+
guidance: float
|
26 |
+
seed: int | None
|
27 |
+
|
28 |
+
|
29 |
+
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
30 |
+
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
31 |
+
usage = (
|
32 |
+
"Usage: Either write your prompt directly, leave this field empty "
|
33 |
+
"to repeat the prompt or write a command starting with a slash:\n"
|
34 |
+
"- '/w <width>' will set the width of the generated image\n"
|
35 |
+
"- '/h <height>' will set the height of the generated image\n"
|
36 |
+
"- '/s <seed>' sets the next seed\n"
|
37 |
+
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
38 |
+
"- '/n <steps>' sets the number of steps\n"
|
39 |
+
"- '/q' to quit"
|
40 |
+
)
|
41 |
+
|
42 |
+
while (prompt := input(user_question)).startswith("/"):
|
43 |
+
if prompt.startswith("/w"):
|
44 |
+
if prompt.count(" ") != 1:
|
45 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
46 |
+
continue
|
47 |
+
_, width = prompt.split()
|
48 |
+
options.width = 16 * (int(width) // 16)
|
49 |
+
print(
|
50 |
+
f"Setting resolution to {options.width} x {options.height} "
|
51 |
+
f"({options.height *options.width/1e6:.2f}MP)"
|
52 |
+
)
|
53 |
+
elif prompt.startswith("/h"):
|
54 |
+
if prompt.count(" ") != 1:
|
55 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
56 |
+
continue
|
57 |
+
_, height = prompt.split()
|
58 |
+
options.height = 16 * (int(height) // 16)
|
59 |
+
print(
|
60 |
+
f"Setting resolution to {options.width} x {options.height} "
|
61 |
+
f"({options.height *options.width/1e6:.2f}MP)"
|
62 |
+
)
|
63 |
+
elif prompt.startswith("/g"):
|
64 |
+
if prompt.count(" ") != 1:
|
65 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
66 |
+
continue
|
67 |
+
_, guidance = prompt.split()
|
68 |
+
options.guidance = float(guidance)
|
69 |
+
print(f"Setting guidance to {options.guidance}")
|
70 |
+
elif prompt.startswith("/s"):
|
71 |
+
if prompt.count(" ") != 1:
|
72 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
73 |
+
continue
|
74 |
+
_, seed = prompt.split()
|
75 |
+
options.seed = int(seed)
|
76 |
+
print(f"Setting seed to {options.seed}")
|
77 |
+
elif prompt.startswith("/n"):
|
78 |
+
if prompt.count(" ") != 1:
|
79 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
80 |
+
continue
|
81 |
+
_, steps = prompt.split()
|
82 |
+
options.num_steps = int(steps)
|
83 |
+
print(f"Setting seed to {options.num_steps}")
|
84 |
+
elif prompt.startswith("/q"):
|
85 |
+
print("Quitting")
|
86 |
+
return None
|
87 |
+
else:
|
88 |
+
if not prompt.startswith("/h"):
|
89 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
90 |
+
print(usage)
|
91 |
+
if prompt != "":
|
92 |
+
options.prompt = prompt
|
93 |
+
return options
|
94 |
+
|
95 |
+
|
96 |
+
@torch.inference_mode()
|
97 |
+
def main(
|
98 |
+
name: str = "flux-schnell",
|
99 |
+
width: int = 1360,
|
100 |
+
height: int = 768,
|
101 |
+
seed: int | None = None,
|
102 |
+
prompt: str = (
|
103 |
+
"a photo of a forest with mist swirling around the tree trunks. The word "
|
104 |
+
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
105 |
+
),
|
106 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
107 |
+
num_steps: int | None = None,
|
108 |
+
loop: bool = False,
|
109 |
+
guidance: float = 3.5,
|
110 |
+
offload: bool = False,
|
111 |
+
output_dir: str = "output",
|
112 |
+
add_sampling_metadata: bool = True,
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
Sample the flux model. Either interactively (set `--loop`) or run for a
|
116 |
+
single image.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
name: Name of the model to load
|
120 |
+
height: height of the sample in pixels (should be a multiple of 16)
|
121 |
+
width: width of the sample in pixels (should be a multiple of 16)
|
122 |
+
seed: Set a seed for sampling
|
123 |
+
output_name: where to save the output image, `{idx}` will be replaced
|
124 |
+
by the index of the sample
|
125 |
+
prompt: Prompt used for sampling
|
126 |
+
device: Pytorch device
|
127 |
+
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
128 |
+
loop: start an interactive session and sample multiple times
|
129 |
+
guidance: guidance value used for guidance distillation
|
130 |
+
add_sampling_metadata: Add the prompt to the image Exif metadata
|
131 |
+
"""
|
132 |
+
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
|
133 |
+
|
134 |
+
if name not in configs:
|
135 |
+
available = ", ".join(configs.keys())
|
136 |
+
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
137 |
+
|
138 |
+
torch_device = torch.device(device)
|
139 |
+
if num_steps is None:
|
140 |
+
num_steps = 4 if name == "flux-schnell" else 50
|
141 |
+
|
142 |
+
# allow for packing and conversion to latent space
|
143 |
+
height = 16 * (height // 16)
|
144 |
+
width = 16 * (width // 16)
|
145 |
+
|
146 |
+
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
147 |
+
if not os.path.exists(output_dir):
|
148 |
+
os.makedirs(output_dir)
|
149 |
+
idx = 0
|
150 |
+
else:
|
151 |
+
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)]
|
152 |
+
if len(fns) > 0:
|
153 |
+
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
154 |
+
else:
|
155 |
+
idx = 0
|
156 |
+
|
157 |
+
# init all components
|
158 |
+
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
|
159 |
+
clip = load_clip(torch_device)
|
160 |
+
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
161 |
+
ae = load_ae(name, device="cpu" if offload else torch_device)
|
162 |
+
|
163 |
+
rng = torch.Generator(device="cpu")
|
164 |
+
opts = SamplingOptions(
|
165 |
+
prompt=prompt,
|
166 |
+
width=width,
|
167 |
+
height=height,
|
168 |
+
num_steps=num_steps,
|
169 |
+
guidance=guidance,
|
170 |
+
seed=seed,
|
171 |
+
)
|
172 |
+
|
173 |
+
if loop:
|
174 |
+
opts = parse_prompt(opts)
|
175 |
+
|
176 |
+
while opts is not None:
|
177 |
+
if opts.seed is None:
|
178 |
+
opts.seed = rng.seed()
|
179 |
+
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
180 |
+
t0 = time.perf_counter()
|
181 |
+
|
182 |
+
# prepare input
|
183 |
+
x = get_noise(
|
184 |
+
1,
|
185 |
+
opts.height,
|
186 |
+
opts.width,
|
187 |
+
device=torch_device,
|
188 |
+
dtype=torch.bfloat16,
|
189 |
+
seed=opts.seed,
|
190 |
+
)
|
191 |
+
opts.seed = None
|
192 |
+
if offload:
|
193 |
+
ae = ae.cpu()
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
t5, clip = t5.to(torch_device), clip.to(torch_device)
|
196 |
+
inp = prepare(t5, clip, x, prompt=opts.prompt)
|
197 |
+
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
198 |
+
|
199 |
+
# offload TEs to CPU, load model to gpu
|
200 |
+
if offload:
|
201 |
+
t5, clip = t5.cpu(), clip.cpu()
|
202 |
+
torch.cuda.empty_cache()
|
203 |
+
model = model.to(torch_device)
|
204 |
+
|
205 |
+
# denoise initial noise
|
206 |
+
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
207 |
+
|
208 |
+
# offload model, load autoencoder to gpu
|
209 |
+
if offload:
|
210 |
+
model.cpu()
|
211 |
+
torch.cuda.empty_cache()
|
212 |
+
ae.decoder.to(x.device)
|
213 |
+
|
214 |
+
# decode latents to pixel space
|
215 |
+
x = unpack(x.float(), opts.height, opts.width)
|
216 |
+
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
217 |
+
x = ae.decode(x)
|
218 |
+
t1 = time.perf_counter()
|
219 |
+
|
220 |
+
fn = output_name.format(idx=idx)
|
221 |
+
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
|
222 |
+
# bring into PIL format and save
|
223 |
+
x = x.clamp(-1, 1)
|
224 |
+
x = embed_watermark(x.float())
|
225 |
+
x = rearrange(x[0], "c h w -> h w c")
|
226 |
+
|
227 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
228 |
+
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
229 |
+
|
230 |
+
if nsfw_score < NSFW_THRESHOLD:
|
231 |
+
exif_data = Image.Exif()
|
232 |
+
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
|
233 |
+
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
|
234 |
+
exif_data[ExifTags.Base.Model] = name
|
235 |
+
if add_sampling_metadata:
|
236 |
+
exif_data[ExifTags.Base.ImageDescription] = prompt
|
237 |
+
img.save(fn, exif=exif_data, quality=95, subsampling=0)
|
238 |
+
idx += 1
|
239 |
+
else:
|
240 |
+
print("Your generated image may contain NSFW content.")
|
241 |
+
|
242 |
+
if loop:
|
243 |
+
print("-" * 80)
|
244 |
+
opts = parse_prompt(opts)
|
245 |
+
else:
|
246 |
+
opts = None
|
247 |
+
|
248 |
+
|
249 |
+
def app():
|
250 |
+
Fire(main)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
app()
|
src/flux/controlnet.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
8 |
+
MLPEmbedder, SingleStreamBlock,
|
9 |
+
timestep_embedding)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FluxParams:
|
14 |
+
in_channels: int
|
15 |
+
vec_in_dim: int
|
16 |
+
context_in_dim: int
|
17 |
+
hidden_size: int
|
18 |
+
mlp_ratio: float
|
19 |
+
num_heads: int
|
20 |
+
depth: int
|
21 |
+
depth_single_blocks: int
|
22 |
+
axes_dim: list[int]
|
23 |
+
theta: int
|
24 |
+
qkv_bias: bool
|
25 |
+
guidance_embed: bool
|
26 |
+
|
27 |
+
def zero_module(module):
|
28 |
+
for p in module.parameters():
|
29 |
+
nn.init.zeros_(p)
|
30 |
+
return module
|
31 |
+
|
32 |
+
|
33 |
+
class ControlNetFlux(nn.Module):
|
34 |
+
"""
|
35 |
+
Transformer model for flow matching on sequences.
|
36 |
+
"""
|
37 |
+
_supports_gradient_checkpointing = True
|
38 |
+
|
39 |
+
def __init__(self, params: FluxParams, controlnet_depth=2):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.params = params
|
43 |
+
self.in_channels = params.in_channels
|
44 |
+
self.out_channels = self.in_channels
|
45 |
+
if params.hidden_size % params.num_heads != 0:
|
46 |
+
raise ValueError(
|
47 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
48 |
+
)
|
49 |
+
pe_dim = params.hidden_size // params.num_heads
|
50 |
+
if sum(params.axes_dim) != pe_dim:
|
51 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
52 |
+
self.hidden_size = params.hidden_size
|
53 |
+
self.num_heads = params.num_heads
|
54 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
55 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
56 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
57 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
58 |
+
self.guidance_in = (
|
59 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
60 |
+
)
|
61 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
62 |
+
|
63 |
+
self.double_blocks = nn.ModuleList(
|
64 |
+
[
|
65 |
+
DoubleStreamBlock(
|
66 |
+
self.hidden_size,
|
67 |
+
self.num_heads,
|
68 |
+
mlp_ratio=params.mlp_ratio,
|
69 |
+
qkv_bias=params.qkv_bias,
|
70 |
+
)
|
71 |
+
for _ in range(controlnet_depth)
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
# add ControlNet blocks
|
76 |
+
self.controlnet_blocks = nn.ModuleList([])
|
77 |
+
for _ in range(controlnet_depth):
|
78 |
+
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
79 |
+
controlnet_block = zero_module(controlnet_block)
|
80 |
+
self.controlnet_blocks.append(controlnet_block)
|
81 |
+
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
82 |
+
self.gradient_checkpointing = False
|
83 |
+
self.input_hint_block = nn.Sequential(
|
84 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
85 |
+
nn.SiLU(),
|
86 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
87 |
+
nn.SiLU(),
|
88 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
89 |
+
nn.SiLU(),
|
90 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
91 |
+
nn.SiLU(),
|
92 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
93 |
+
nn.SiLU(),
|
94 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
95 |
+
nn.SiLU(),
|
96 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
97 |
+
nn.SiLU(),
|
98 |
+
zero_module(nn.Conv2d(16, 16, 3, padding=1))
|
99 |
+
)
|
100 |
+
|
101 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
102 |
+
if hasattr(module, "gradient_checkpointing"):
|
103 |
+
module.gradient_checkpointing = value
|
104 |
+
|
105 |
+
|
106 |
+
@property
|
107 |
+
def attn_processors(self):
|
108 |
+
# set recursively
|
109 |
+
processors = {}
|
110 |
+
|
111 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
112 |
+
if hasattr(module, "set_processor"):
|
113 |
+
processors[f"{name}.processor"] = module.processor
|
114 |
+
|
115 |
+
for sub_name, child in module.named_children():
|
116 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
117 |
+
|
118 |
+
return processors
|
119 |
+
|
120 |
+
for name, module in self.named_children():
|
121 |
+
fn_recursive_add_processors(name, module, processors)
|
122 |
+
|
123 |
+
return processors
|
124 |
+
|
125 |
+
def set_attn_processor(self, processor):
|
126 |
+
r"""
|
127 |
+
Sets the attention processor to use to compute attention.
|
128 |
+
|
129 |
+
Parameters:
|
130 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
131 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
132 |
+
for **all** `Attention` layers.
|
133 |
+
|
134 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
135 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
136 |
+
|
137 |
+
"""
|
138 |
+
count = len(self.attn_processors.keys())
|
139 |
+
|
140 |
+
if isinstance(processor, dict) and len(processor) != count:
|
141 |
+
raise ValueError(
|
142 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
143 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
144 |
+
)
|
145 |
+
|
146 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
147 |
+
if hasattr(module, "set_processor"):
|
148 |
+
if not isinstance(processor, dict):
|
149 |
+
module.set_processor(processor)
|
150 |
+
else:
|
151 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
152 |
+
|
153 |
+
for sub_name, child in module.named_children():
|
154 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
155 |
+
|
156 |
+
for name, module in self.named_children():
|
157 |
+
fn_recursive_attn_processor(name, module, processor)
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self,
|
161 |
+
img: Tensor,
|
162 |
+
img_ids: Tensor,
|
163 |
+
controlnet_cond: Tensor,
|
164 |
+
txt: Tensor,
|
165 |
+
txt_ids: Tensor,
|
166 |
+
timesteps: Tensor,
|
167 |
+
y: Tensor,
|
168 |
+
guidance: Tensor | None = None,
|
169 |
+
) -> Tensor:
|
170 |
+
if img.ndim != 3 or txt.ndim != 3:
|
171 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
172 |
+
|
173 |
+
# running on sequences img
|
174 |
+
img = self.img_in(img)
|
175 |
+
controlnet_cond = self.input_hint_block(controlnet_cond)
|
176 |
+
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
177 |
+
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
178 |
+
img = img + controlnet_cond
|
179 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
180 |
+
if self.params.guidance_embed:
|
181 |
+
if guidance is None:
|
182 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
183 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
184 |
+
vec = vec + self.vector_in(y)
|
185 |
+
txt = self.txt_in(txt)
|
186 |
+
|
187 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
188 |
+
pe = self.pe_embedder(ids)
|
189 |
+
|
190 |
+
block_res_samples = ()
|
191 |
+
|
192 |
+
for block in self.double_blocks:
|
193 |
+
if self.training and self.gradient_checkpointing:
|
194 |
+
|
195 |
+
def create_custom_forward(module, return_dict=None):
|
196 |
+
def custom_forward(*inputs):
|
197 |
+
if return_dict is not None:
|
198 |
+
return module(*inputs, return_dict=return_dict)
|
199 |
+
else:
|
200 |
+
return module(*inputs)
|
201 |
+
|
202 |
+
return custom_forward
|
203 |
+
|
204 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
205 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
206 |
+
create_custom_forward(block),
|
207 |
+
img,
|
208 |
+
txt,
|
209 |
+
vec,
|
210 |
+
pe,
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
214 |
+
|
215 |
+
block_res_samples = block_res_samples + (img,)
|
216 |
+
|
217 |
+
controlnet_block_res_samples = ()
|
218 |
+
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
219 |
+
block_res_sample = controlnet_block(block_res_sample)
|
220 |
+
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
221 |
+
|
222 |
+
return controlnet_block_res_samples
|
src/flux/math.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
|
6 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
7 |
+
q, k = apply_rope(q, k, pe)
|
8 |
+
|
9 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
10 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
11 |
+
|
12 |
+
return x
|
13 |
+
|
14 |
+
|
15 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
16 |
+
assert dim % 2 == 0
|
17 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
18 |
+
omega = 1.0 / (theta**scale)
|
19 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
20 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
21 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
22 |
+
return out.float()
|
23 |
+
|
24 |
+
|
25 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
26 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
27 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
28 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
29 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
30 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
src/flux/model.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
8 |
+
MLPEmbedder, SingleStreamBlock,
|
9 |
+
timestep_embedding)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FluxParams:
|
14 |
+
in_channels: int
|
15 |
+
vec_in_dim: int
|
16 |
+
context_in_dim: int
|
17 |
+
hidden_size: int
|
18 |
+
mlp_ratio: float
|
19 |
+
num_heads: int
|
20 |
+
depth: int
|
21 |
+
depth_single_blocks: int
|
22 |
+
axes_dim: list[int]
|
23 |
+
theta: int
|
24 |
+
qkv_bias: bool
|
25 |
+
guidance_embed: bool
|
26 |
+
|
27 |
+
|
28 |
+
class Flux(nn.Module):
|
29 |
+
"""
|
30 |
+
Transformer model for flow matching on sequences.
|
31 |
+
"""
|
32 |
+
_supports_gradient_checkpointing = True
|
33 |
+
|
34 |
+
def __init__(self, params: FluxParams):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.params = params
|
38 |
+
self.in_channels = params.in_channels
|
39 |
+
self.out_channels = self.in_channels
|
40 |
+
if params.hidden_size % params.num_heads != 0:
|
41 |
+
raise ValueError(
|
42 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
43 |
+
)
|
44 |
+
pe_dim = params.hidden_size // params.num_heads
|
45 |
+
if sum(params.axes_dim) != pe_dim:
|
46 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
47 |
+
self.hidden_size = params.hidden_size
|
48 |
+
self.num_heads = params.num_heads
|
49 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
50 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
51 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
52 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
53 |
+
self.guidance_in = (
|
54 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
55 |
+
)
|
56 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
57 |
+
|
58 |
+
self.double_blocks = nn.ModuleList(
|
59 |
+
[
|
60 |
+
DoubleStreamBlock(
|
61 |
+
self.hidden_size,
|
62 |
+
self.num_heads,
|
63 |
+
mlp_ratio=params.mlp_ratio,
|
64 |
+
qkv_bias=params.qkv_bias,
|
65 |
+
)
|
66 |
+
for _ in range(params.depth)
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
self.single_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
73 |
+
for _ in range(params.depth_single_blocks)
|
74 |
+
]
|
75 |
+
)
|
76 |
+
|
77 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
78 |
+
self.gradient_checkpointing = False
|
79 |
+
|
80 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
81 |
+
if hasattr(module, "gradient_checkpointing"):
|
82 |
+
module.gradient_checkpointing = value
|
83 |
+
|
84 |
+
@property
|
85 |
+
def attn_processors(self):
|
86 |
+
# set recursively
|
87 |
+
processors = {}
|
88 |
+
|
89 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
90 |
+
if hasattr(module, "set_processor"):
|
91 |
+
processors[f"{name}.processor"] = module.processor
|
92 |
+
|
93 |
+
for sub_name, child in module.named_children():
|
94 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
95 |
+
|
96 |
+
return processors
|
97 |
+
|
98 |
+
for name, module in self.named_children():
|
99 |
+
fn_recursive_add_processors(name, module, processors)
|
100 |
+
|
101 |
+
return processors
|
102 |
+
|
103 |
+
def set_attn_processor(self, processor):
|
104 |
+
r"""
|
105 |
+
Sets the attention processor to use to compute attention.
|
106 |
+
|
107 |
+
Parameters:
|
108 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
109 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
110 |
+
for **all** `Attention` layers.
|
111 |
+
|
112 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
113 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
114 |
+
|
115 |
+
"""
|
116 |
+
count = len(self.attn_processors.keys())
|
117 |
+
|
118 |
+
if isinstance(processor, dict) and len(processor) != count:
|
119 |
+
raise ValueError(
|
120 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
121 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
122 |
+
)
|
123 |
+
|
124 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
125 |
+
if hasattr(module, "set_processor"):
|
126 |
+
if not isinstance(processor, dict):
|
127 |
+
module.set_processor(processor)
|
128 |
+
else:
|
129 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
130 |
+
|
131 |
+
for sub_name, child in module.named_children():
|
132 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
133 |
+
|
134 |
+
for name, module in self.named_children():
|
135 |
+
fn_recursive_attn_processor(name, module, processor)
|
136 |
+
|
137 |
+
def forward(
|
138 |
+
self,
|
139 |
+
img: Tensor,
|
140 |
+
img_ids: Tensor,
|
141 |
+
txt: Tensor,
|
142 |
+
txt_ids: Tensor,
|
143 |
+
timesteps: Tensor,
|
144 |
+
y: Tensor,
|
145 |
+
block_controlnet_hidden_states=None,
|
146 |
+
guidance: Tensor | None = None,
|
147 |
+
) -> Tensor:
|
148 |
+
if img.ndim != 3 or txt.ndim != 3:
|
149 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
150 |
+
|
151 |
+
# running on sequences img
|
152 |
+
img = self.img_in(img)
|
153 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
154 |
+
if self.params.guidance_embed:
|
155 |
+
if guidance is None:
|
156 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
157 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
158 |
+
vec = vec + self.vector_in(y)
|
159 |
+
txt = self.txt_in(txt)
|
160 |
+
|
161 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
162 |
+
pe = self.pe_embedder(ids)
|
163 |
+
if block_controlnet_hidden_states is not None:
|
164 |
+
controlnet_depth = len(block_controlnet_hidden_states)
|
165 |
+
for index_block, block in enumerate(self.double_blocks):
|
166 |
+
if self.training and self.gradient_checkpointing:
|
167 |
+
|
168 |
+
def create_custom_forward(module, return_dict=None):
|
169 |
+
def custom_forward(*inputs):
|
170 |
+
if return_dict is not None:
|
171 |
+
return module(*inputs, return_dict=return_dict)
|
172 |
+
else:
|
173 |
+
return module(*inputs)
|
174 |
+
|
175 |
+
return custom_forward
|
176 |
+
|
177 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
178 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
179 |
+
create_custom_forward(block),
|
180 |
+
img,
|
181 |
+
txt,
|
182 |
+
vec,
|
183 |
+
pe,
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
187 |
+
# controlnet residual
|
188 |
+
if block_controlnet_hidden_states is not None:
|
189 |
+
img = img + block_controlnet_hidden_states[index_block % 2]
|
190 |
+
|
191 |
+
|
192 |
+
img = torch.cat((txt, img), 1)
|
193 |
+
for block in self.single_blocks:
|
194 |
+
if self.training and self.gradient_checkpointing:
|
195 |
+
|
196 |
+
def create_custom_forward(module, return_dict=None):
|
197 |
+
def custom_forward(*inputs):
|
198 |
+
if return_dict is not None:
|
199 |
+
return module(*inputs, return_dict=return_dict)
|
200 |
+
else:
|
201 |
+
return module(*inputs)
|
202 |
+
|
203 |
+
return custom_forward
|
204 |
+
|
205 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
206 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
207 |
+
create_custom_forward(block),
|
208 |
+
img,
|
209 |
+
vec,
|
210 |
+
pe,
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
img = block(img, vec=vec, pe=pe)
|
214 |
+
img = img[:, txt.shape[1] :, ...]
|
215 |
+
|
216 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
217 |
+
return img
|
src/flux/sampling.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from .model import Flux
|
9 |
+
from .modules.conditioner import HFEmbedder
|
10 |
+
|
11 |
+
|
12 |
+
def get_noise(
|
13 |
+
num_samples: int,
|
14 |
+
height: int,
|
15 |
+
width: int,
|
16 |
+
device: torch.device,
|
17 |
+
dtype: torch.dtype,
|
18 |
+
seed: int,
|
19 |
+
):
|
20 |
+
return torch.randn(
|
21 |
+
num_samples,
|
22 |
+
16,
|
23 |
+
# allow for packing
|
24 |
+
2 * math.ceil(height / 16),
|
25 |
+
2 * math.ceil(width / 16),
|
26 |
+
device=device,
|
27 |
+
dtype=dtype,
|
28 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
33 |
+
bs, c, h, w = img.shape
|
34 |
+
if bs == 1 and not isinstance(prompt, str):
|
35 |
+
bs = len(prompt)
|
36 |
+
|
37 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
38 |
+
if img.shape[0] == 1 and bs > 1:
|
39 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
40 |
+
|
41 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
42 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
43 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
44 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
45 |
+
|
46 |
+
if isinstance(prompt, str):
|
47 |
+
prompt = [prompt]
|
48 |
+
txt = t5(prompt)
|
49 |
+
if txt.shape[0] == 1 and bs > 1:
|
50 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
51 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
52 |
+
|
53 |
+
vec = clip(prompt)
|
54 |
+
if vec.shape[0] == 1 and bs > 1:
|
55 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
56 |
+
|
57 |
+
return {
|
58 |
+
"img": img,
|
59 |
+
"img_ids": img_ids.to(img.device),
|
60 |
+
"txt": txt.to(img.device),
|
61 |
+
"txt_ids": txt_ids.to(img.device),
|
62 |
+
"vec": vec.to(img.device),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
67 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
68 |
+
|
69 |
+
|
70 |
+
def get_lin_function(
|
71 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
72 |
+
) -> Callable[[float], float]:
|
73 |
+
m = (y2 - y1) / (x2 - x1)
|
74 |
+
b = y1 - m * x1
|
75 |
+
return lambda x: m * x + b
|
76 |
+
|
77 |
+
|
78 |
+
def get_schedule(
|
79 |
+
num_steps: int,
|
80 |
+
image_seq_len: int,
|
81 |
+
base_shift: float = 0.5,
|
82 |
+
max_shift: float = 1.15,
|
83 |
+
shift: bool = True,
|
84 |
+
) -> list[float]:
|
85 |
+
# extra step for zero
|
86 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
87 |
+
|
88 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
89 |
+
if shift:
|
90 |
+
# eastimate mu based on linear estimation between two points
|
91 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
92 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
93 |
+
|
94 |
+
return timesteps.tolist()
|
95 |
+
|
96 |
+
|
97 |
+
def denoise(
|
98 |
+
model: Flux,
|
99 |
+
# model input
|
100 |
+
img: Tensor,
|
101 |
+
img_ids: Tensor,
|
102 |
+
txt: Tensor,
|
103 |
+
txt_ids: Tensor,
|
104 |
+
vec: Tensor,
|
105 |
+
# sampling parameters
|
106 |
+
timesteps: list[float],
|
107 |
+
guidance: float = 4.0,
|
108 |
+
use_gs=False,
|
109 |
+
gs=4,
|
110 |
+
):
|
111 |
+
# this is ignored for schnell
|
112 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
113 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
114 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
115 |
+
pred = model(
|
116 |
+
img=img,
|
117 |
+
img_ids=img_ids,
|
118 |
+
txt=txt,
|
119 |
+
txt_ids=txt_ids,
|
120 |
+
y=vec,
|
121 |
+
timesteps=t_vec,
|
122 |
+
guidance=guidance_vec,
|
123 |
+
)
|
124 |
+
if use_gs:
|
125 |
+
pred_uncond, pred_text = pred.chunk(2)
|
126 |
+
pred = pred_uncond + gs * (pred_text - pred_uncond)
|
127 |
+
|
128 |
+
img = img + (t_prev - t_curr) * pred
|
129 |
+
#if use_gs:
|
130 |
+
# img = torch.cat([img] * 2)
|
131 |
+
|
132 |
+
return img
|
133 |
+
|
134 |
+
def denoise_controlnet(
|
135 |
+
model: Flux,
|
136 |
+
controlnet:None,
|
137 |
+
# model input
|
138 |
+
img: Tensor,
|
139 |
+
img_ids: Tensor,
|
140 |
+
txt: Tensor,
|
141 |
+
txt_ids: Tensor,
|
142 |
+
vec: Tensor,
|
143 |
+
controlnet_cond,
|
144 |
+
# sampling parameters
|
145 |
+
timesteps: list[float],
|
146 |
+
guidance: float = 4.0,
|
147 |
+
controlnet_gs=0.7,
|
148 |
+
):
|
149 |
+
# this is ignored for schnell
|
150 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
151 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
152 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
153 |
+
block_res_samples = controlnet(
|
154 |
+
img=img,
|
155 |
+
img_ids=img_ids,
|
156 |
+
controlnet_cond=controlnet_cond,
|
157 |
+
txt=txt,
|
158 |
+
txt_ids=txt_ids,
|
159 |
+
y=vec,
|
160 |
+
timesteps=t_vec,
|
161 |
+
guidance=guidance_vec,
|
162 |
+
)
|
163 |
+
pred = model(
|
164 |
+
img=img,
|
165 |
+
img_ids=img_ids,
|
166 |
+
txt=txt,
|
167 |
+
txt_ids=txt_ids,
|
168 |
+
y=vec,
|
169 |
+
timesteps=t_vec,
|
170 |
+
guidance=guidance_vec,
|
171 |
+
block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples]
|
172 |
+
)
|
173 |
+
|
174 |
+
img = img + (t_prev - t_curr) * pred
|
175 |
+
#if use_gs:
|
176 |
+
# img = torch.cat([img] * 2)
|
177 |
+
|
178 |
+
return img
|
179 |
+
|
180 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
181 |
+
return rearrange(
|
182 |
+
x,
|
183 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
184 |
+
h=math.ceil(height / 16),
|
185 |
+
w=math.ceil(width / 16),
|
186 |
+
ph=2,
|
187 |
+
pw=2,
|
188 |
+
)
|
src/flux/util.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from safetensors.torch import load_file as load_sft
|
8 |
+
|
9 |
+
from .model import Flux, FluxParams
|
10 |
+
from .controlnet import ControlNetFlux
|
11 |
+
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
12 |
+
from .modules.conditioner import HFEmbedder
|
13 |
+
|
14 |
+
from safetensors import safe_open
|
15 |
+
|
16 |
+
def load_safetensors(path):
|
17 |
+
tensors = {}
|
18 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
19 |
+
for key in f.keys():
|
20 |
+
tensors[key] = f.get_tensor(key)
|
21 |
+
return tensors
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class ModelSpec:
|
25 |
+
params: FluxParams
|
26 |
+
ae_params: AutoEncoderParams
|
27 |
+
ckpt_path: str | None
|
28 |
+
ae_path: str | None
|
29 |
+
repo_id: str | None
|
30 |
+
repo_flow: str | None
|
31 |
+
repo_ae: str | None
|
32 |
+
|
33 |
+
|
34 |
+
configs = {
|
35 |
+
"flux-dev": ModelSpec(
|
36 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
37 |
+
repo_flow="flux1-dev.safetensors",
|
38 |
+
repo_ae="ae.safetensors",
|
39 |
+
ckpt_path=os.getenv("FLUX_DEV"),
|
40 |
+
params=FluxParams(
|
41 |
+
in_channels=64,
|
42 |
+
vec_in_dim=768,
|
43 |
+
context_in_dim=4096,
|
44 |
+
hidden_size=3072,
|
45 |
+
mlp_ratio=4.0,
|
46 |
+
num_heads=24,
|
47 |
+
depth=19,
|
48 |
+
depth_single_blocks=38,
|
49 |
+
axes_dim=[16, 56, 56],
|
50 |
+
theta=10_000,
|
51 |
+
qkv_bias=True,
|
52 |
+
guidance_embed=True,
|
53 |
+
),
|
54 |
+
ae_path=os.getenv("AE"),
|
55 |
+
ae_params=AutoEncoderParams(
|
56 |
+
resolution=256,
|
57 |
+
in_channels=3,
|
58 |
+
ch=128,
|
59 |
+
out_ch=3,
|
60 |
+
ch_mult=[1, 2, 4, 4],
|
61 |
+
num_res_blocks=2,
|
62 |
+
z_channels=16,
|
63 |
+
scale_factor=0.3611,
|
64 |
+
shift_factor=0.1159,
|
65 |
+
),
|
66 |
+
),
|
67 |
+
"flux-schnell": ModelSpec(
|
68 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
69 |
+
repo_flow="flux1-schnell.safetensors",
|
70 |
+
repo_ae="ae.safetensors",
|
71 |
+
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
72 |
+
params=FluxParams(
|
73 |
+
in_channels=64,
|
74 |
+
vec_in_dim=768,
|
75 |
+
context_in_dim=4096,
|
76 |
+
hidden_size=3072,
|
77 |
+
mlp_ratio=4.0,
|
78 |
+
num_heads=24,
|
79 |
+
depth=19,
|
80 |
+
depth_single_blocks=38,
|
81 |
+
axes_dim=[16, 56, 56],
|
82 |
+
theta=10_000,
|
83 |
+
qkv_bias=True,
|
84 |
+
guidance_embed=False,
|
85 |
+
),
|
86 |
+
ae_path=os.getenv("AE"),
|
87 |
+
ae_params=AutoEncoderParams(
|
88 |
+
resolution=256,
|
89 |
+
in_channels=3,
|
90 |
+
ch=128,
|
91 |
+
out_ch=3,
|
92 |
+
ch_mult=[1, 2, 4, 4],
|
93 |
+
num_res_blocks=2,
|
94 |
+
z_channels=16,
|
95 |
+
scale_factor=0.3611,
|
96 |
+
shift_factor=0.1159,
|
97 |
+
),
|
98 |
+
),
|
99 |
+
}
|
100 |
+
|
101 |
+
|
102 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
103 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
104 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
105 |
+
print("\n" + "-" * 79 + "\n")
|
106 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
107 |
+
elif len(missing) > 0:
|
108 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
109 |
+
elif len(unexpected) > 0:
|
110 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
111 |
+
|
112 |
+
|
113 |
+
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
114 |
+
# Loading Flux
|
115 |
+
print("Init model")
|
116 |
+
ckpt_path = configs[name].ckpt_path
|
117 |
+
if (
|
118 |
+
ckpt_path is None
|
119 |
+
and configs[name].repo_id is not None
|
120 |
+
and configs[name].repo_flow is not None
|
121 |
+
and hf_download
|
122 |
+
):
|
123 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
124 |
+
|
125 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
126 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
127 |
+
|
128 |
+
if ckpt_path is not None:
|
129 |
+
print("Loading checkpoint")
|
130 |
+
# load_sft doesn't support torch.device
|
131 |
+
sd = load_sft(ckpt_path, device=str(device))
|
132 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
133 |
+
print_load_warning(missing, unexpected)
|
134 |
+
return model
|
135 |
+
|
136 |
+
def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
137 |
+
# Loading Flux
|
138 |
+
print("Init model")
|
139 |
+
ckpt_path = configs[name].ckpt_path
|
140 |
+
if (
|
141 |
+
ckpt_path is None
|
142 |
+
and configs[name].repo_id is not None
|
143 |
+
and configs[name].repo_flow is not None
|
144 |
+
and hf_download
|
145 |
+
):
|
146 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
147 |
+
|
148 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
149 |
+
model = Flux(configs[name].params)
|
150 |
+
|
151 |
+
if ckpt_path is not None:
|
152 |
+
print("Loading checkpoint")
|
153 |
+
# load_sft doesn't support torch.device
|
154 |
+
sd = load_sft(ckpt_path, device=str(device))
|
155 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
156 |
+
print_load_warning(missing, unexpected)
|
157 |
+
return model
|
158 |
+
|
159 |
+
def load_controlnet(name, device, transformer=None):
|
160 |
+
with torch.device(device):
|
161 |
+
controlnet = ControlNetFlux(configs[name].params)
|
162 |
+
if transformer is not None:
|
163 |
+
controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
164 |
+
return controlnet
|
165 |
+
|
166 |
+
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
167 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
168 |
+
return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
169 |
+
|
170 |
+
|
171 |
+
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
172 |
+
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
173 |
+
|
174 |
+
|
175 |
+
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
176 |
+
ckpt_path = configs[name].ae_path
|
177 |
+
if (
|
178 |
+
ckpt_path is None
|
179 |
+
and configs[name].repo_id is not None
|
180 |
+
and configs[name].repo_ae is not None
|
181 |
+
and hf_download
|
182 |
+
):
|
183 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
|
184 |
+
|
185 |
+
# Loading the autoencoder
|
186 |
+
print("Init AE")
|
187 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
188 |
+
ae = AutoEncoder(configs[name].ae_params)
|
189 |
+
|
190 |
+
if ckpt_path is not None:
|
191 |
+
sd = load_sft(ckpt_path, device=str(device))
|
192 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
193 |
+
print_load_warning(missing, unexpected)
|
194 |
+
return ae
|
195 |
+
|
196 |
+
|
197 |
+
class WatermarkEmbedder:
|
198 |
+
def __init__(self, watermark):
|
199 |
+
self.watermark = watermark
|
200 |
+
self.num_bits = len(WATERMARK_BITS)
|
201 |
+
self.encoder = WatermarkEncoder()
|
202 |
+
self.encoder.set_watermark("bits", self.watermark)
|
203 |
+
|
204 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
205 |
+
"""
|
206 |
+
Adds a predefined watermark to the input image
|
207 |
+
|
208 |
+
Args:
|
209 |
+
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
same as input but watermarked
|
213 |
+
"""
|
214 |
+
image = 0.5 * image + 0.5
|
215 |
+
squeeze = len(image.shape) == 4
|
216 |
+
if squeeze:
|
217 |
+
image = image[None, ...]
|
218 |
+
n = image.shape[0]
|
219 |
+
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
220 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
221 |
+
# watermarking libary expects input as cv2 BGR format
|
222 |
+
for k in range(image_np.shape[0]):
|
223 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
224 |
+
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
225 |
+
image.device
|
226 |
+
)
|
227 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
228 |
+
if squeeze:
|
229 |
+
image = image[0]
|
230 |
+
image = 2 * image - 1
|
231 |
+
return image
|
232 |
+
|
233 |
+
|
234 |
+
# A fixed 48-bit message that was choosen at random
|
235 |
+
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
236 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
237 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|