license: mit
language:
  - en
library_name: diffusers
pipeline_tag: image-to-image
tags:
  - medical

This repository contains a model for generating synthetic 2D CT scans, intended solely for research purposes. The base model is Stable Diffusion 3 Medium, extended with ControlNet, a technique that conditions image generation on an additional input (here, a semantic segmentation mask) to give more precise control over the output.

For pretraining, we used the Atlas Dataset from Johns Hopkins University, which provided a broad range of medical imaging data for the initial training phase. Our aim is to contribute to the medical imaging field by enabling more robust and versatile synthetic data generation.

Training Details
Image size: (128, 128)
Batch size: 8 x 28 x 12
Compute: 8 x NVIDIA A6000 48GB

Code for generation:

from diffusers import StableDiffusion3ControlNetPipeline
import torch
from PIL import Image
import numpy as np
from transformers import T5Tokenizer
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"


class_dict_atlas = {
        0:(0, 0, 0),
        1:(255, 60, 0),
        2:(255, 60, 232),
        3:(134, 79, 117),
        4:(125, 0, 190),
        5:(117, 200, 191),
        6:(230, 91, 101),
        7:(255, 0, 155),
        8:(75, 205, 155),
        9:(100, 37, 200)
}

name_class_dict = {
        0:"background",
        1:"aorta",
        2:"kidney_left",
        3:"liver",
        4:"postcava",
        5:"stomach",
        6:"gall_bladder",
        7:"kidney_right",
        8:"pancreas",
        9:"spleen"
}

def rgb_to_onehot(rgb_arr, color_dict=class_dict_atlas):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for class_id, color in color_dict.items():
        # Mark the pixels whose RGB value exactly matches this class color.
        arr[:, :, class_id] = np.all(rgb_arr == color, axis=-1)
    return arr



pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)


pipe.tokenizer_3 = T5Tokenizer.from_pretrained(
        "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
        subfolder='tokenizer_3'
)

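# Model CPU offload moves each component to the GPU only while it is in use,
# so an explicit pipe.to('cuda') beforehand is not needed.
pipe.enable_model_cpu_offload()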


generator = torch.Generator(device="cuda").manual_seed(1)
images = Image.open("<Give mask image for semantic guidance>")
shape = images.size

# Recover per-pixel class IDs from the mask colors, then list the classes present.
npi = np.asarray(images.convert("RGB"))
npi = rgb_to_onehot(npi).argmax(-1)
unique_ids = np.unique(npi)

print('CT image containg '+" ".join([name_class_dict[i] for i in unique_ids]))

image = pipe(
    prompt='CT image containg '+" ".join([name_class_dict[i] for i in unique_ids]),
    control_image=images.convert('RGB'),
    height=128,
    width=128,
    num_inference_steps=50,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

# Resize the 128x128 output back to the mask's original size before saving.
image.resize(shape).save('result.png')
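
The script above expects a mask image already painted in the class colors. As a minimal sketch of how such a mask and its matching prompt might be prepared from an integer label map, the snippet below uses plain Python only. The helper names `label_map_to_rgb` and `build_prompt` and the `toy_labels` grid are hypothetical and not part of this repository; the palette, the class names, and the prompt prefix (including its spelling) are copied verbatim from the generation script.

```python
# Palette and class names copied from the generation script.
CLASS_COLORS = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

CLASS_NAMES = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}


def label_map_to_rgb(label_rows, palette=CLASS_COLORS):
    """Hypothetical helper: map a 2D grid of class IDs to a grid of RGB tuples."""
    return [[palette[class_id] for class_id in row] for row in label_rows]


def build_prompt(label_rows, names=CLASS_NAMES):
    """Hypothetical helper: build the prompt in the form the script passes to the pipeline."""
    # Sorted unique IDs mirror what np.unique returns in the script.
    unique_ids = sorted({class_id for row in label_rows for class_id in row})
    # The prefix is kept exactly as written in the generation script.
    return 'CT image containg ' + " ".join(names[i] for i in unique_ids)


# Tiny illustrative 2x3 label map: background, liver, spleen.
toy_labels = [[0, 3, 3], [0, 9, 9]]
toy_mask = label_map_to_rgb(toy_labels)
toy_prompt = build_prompt(toy_labels)
print(toy_prompt)
```

To pass such a grid to the pipeline, it could be converted with `np.asarray(toy_mask, dtype=np.uint8)` and then `Image.fromarray(...)`, using the `numpy` and `PIL` imports the script already has.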