|
import torch |
|
from PIL import Image |
|
|
|
from models.transformer_sd3 import SD3Transformer2DModel |
|
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline |
|
|
|
|
|
def main() -> None:
    """Generate one image with SD3.5-large, image-conditioned via IP-Adapter.

    Loads the IP-Adapter-aware transformer subclass, builds the pipeline on
    CUDA in bfloat16, attaches the IP-Adapter weights and SigLIP image
    encoder, then runs a single text+image-guided generation and writes the
    result to ``./result.jpg``.

    Side effects: downloads/loads model weights, allocates GPU memory,
    reads ``./assets/1.jpg``, writes ``./result.jpg``.
    """
    model_path = 'stabilityai/stable-diffusion-3.5-large'
    ip_adapter_path = './ip-adapter.bin'
    image_encoder_path = "google/siglip-so400m-patch14-384"

    # Load the transformer separately so the project-local subclass (which
    # supports IP-Adapter image tokens) replaces the stock SD3 transformer.
    transformer = SD3Transformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    # Attach IP-Adapter weights plus the SigLIP image encoder.
    # nb_token=64: number of image tokens injected into the transformer
    # (project-defined knob — confirm against init_ipadapter's docs).
    pipe.init_ipadapter(
        ip_adapter_path=ip_adapter_path,
        image_encoder_path=image_encoder_path,
        nb_token=64,
    )

    # Reference image that steers generation through the image encoder.
    ref_img = Image.open('./assets/1.jpg').convert('RGB')

    # Fixed seed for reproducibility; ipadapter_scale=0.5 balances the text
    # prompt against the reference-image conditioning.
    image = pipe(
        width=1024,
        height=1024,
        prompt='a cat',
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=ref_img,
        ipadapter_scale=0.5,
    ).images[0]
    image.save('./result.jpg')


if __name__ == '__main__':
    main()