Spaces:
Runtime error
Runtime error
Fix
Browse files- app.py +1 -1
- flux/modules/conditioner.py +6 -4
app.py
CHANGED
@@ -266,5 +266,5 @@ def create_demo(model_name: str, device: str = "cuda", offload: bool = False):
|
|
266 |
# parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
|
267 |
# args = parser.parse_args()
|
268 |
|
269 |
-
demo = create_demo("flux-
|
270 |
demo.launch()
|
|
|
266 |
# parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
|
267 |
# args = parser.parse_args()
|
268 |
|
269 |
+
demo = create_demo("flux-schnell", None, False)
|
270 |
demo.launch()
|
flux/modules/conditioner.py
CHANGED
@@ -15,13 +15,15 @@ class HFEmbedder(nnx.Module):
|
|
15 |
if self.is_clip:
|
16 |
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
17 |
# self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
18 |
-
self.hf_module
|
19 |
else:
|
20 |
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
21 |
# self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
22 |
-
self.hf_module
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
|
26 |
def tokenize(self, text: list[str]) -> Tensor:
|
27 |
batch_encoding = self.tokenizer(
|
|
|
15 |
if self.is_clip:
|
16 |
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
17 |
# self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
18 |
+
self.hf_module, params = FlaxCLIPTextModel.from_pretrained(version, _do_init=False, **hf_kwargs)
|
19 |
else:
|
20 |
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
21 |
# self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
22 |
+
self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False,**hf_kwargs)
|
23 |
+
self.hf_module._is_initialized = True
|
24 |
+
import jax
|
25 |
+
self.hf_module.params = jax.tree_map(lambda x: jax.device_put(x, jax.devices("cuda")[0]), params)
|
26 |
+
# if dtype==jnp.bfloat16:
|
27 |
|
28 |
def tokenize(self, text: list[str]) -> Tensor:
|
29 |
batch_encoding = self.tokenizer(
|