Fine-tuning on MR dataset
Hello, Mr. BaiBai, I'm a researcher interested in multi-modal medical AI.
I'm trying to fine-tune your M3D-CLIP model on my own dataset, which consists of stroke DWI (ADC) MR images and their radiology reports.
This is a sample of my input reports:
Diffusion restriction at left PVWM and left posterior globus pallidus/putamen~corona radiata.
- with mild T2 signal change
--> acute infarction, likely.
When I ran inference directly on my 20 test samples, it achieved 0.1 accuracy.
After training on 220 samples and evaluating on the previous 20 test samples, the accuracy was unchanged, even though all the similarity scores were different.
So I suspect there is a better strategy for processing my data, choosing parameters, etc., and I would kindly ask whether you have any advice on my process.
This is what I did for training (full script below).
After training, I compute the cosine similarity between the image features and the report features to retrieve the true report (see the retrieval sketch after the training script).
Thanks a lot,
Heeseong Eom
import os
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import pandas as pd
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    Resized, NormalizeIntensityd, EnsureTyped
)
from sklearn.model_selection import train_test_split
from collections import Counter
import logging
import sys
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attempt_num = input("Enter the attempt number: ")
best_model_dir = "/media/data4/hseum/py/rag_task/m3dclip_best_model"
os.makedirs(best_model_dir, exist_ok=True)
best_model_filename = f"best_model_attempt_{attempt_num}.pth"
best_model_path = os.path.join(best_model_dir, best_model_filename)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
internal_data_path = "/media/data4/hseum/stroke_mr/matched_image/BRMH"
internal_report_path = "/media/data4/hseum/report/matched_report/BRMH_mr_report.csv"
internal_df = pd.read_csv(internal_report_path)
all_labels = internal_df['ADP'].unique()
label_mapping = {label: idx for idx, label in enumerate(all_labels)}
print("# of Total Class: ", len(label_mapping))
print("Classes: ", all_labels, " to ", label_mapping)
internal_df['label'] = internal_df['ADP'].map(label_mapping)
def get_image_paths_and_labels(base_path, report_df):
    image_paths = []
    labels = []
    summaries = []
    for _, row in report_df.iterrows():
        patient_id = str(row['PatientID']).strip().zfill(8)
        date = str(row['Date'])
        image_path = os.path.join(base_path, patient_id, date, 'mr.nii.gz')
        if os.path.exists(image_path):
            image_paths.append(image_path)
            labels.append(row['label'])
            summaries.append(row['summary'])
    return image_paths, labels, summaries
internal_image_paths, internal_labels, internal_summaries = get_image_paths_and_labels(internal_data_path, internal_df)
print("Total internal image paths:", len(internal_image_paths))
class_counts = Counter(internal_df['label'])
total_count = sum(class_counts.values())
# Inverse-frequency class weights (computed here but not used in the training loop below).
class_weights = {label: total_count / count for label, count in class_counts.items()}
internal_image_paths, internal_val_image_paths, internal_labels, internal_val_labels, internal_summaries, internal_val_summaries = train_test_split(
    internal_image_paths, internal_labels, internal_summaries, test_size=20, random_state=42
)
def preprocess_image(image_path):
    # Cache the preprocessed volume as .npy next to the original NIfTI file.
    npy_path = image_path.replace('.nii.gz', '.npy')
    if os.path.exists(npy_path):
        return npy_path
    transform = Compose([
        LoadImaged(keys=["img"]),
        EnsureChannelFirstd(keys=["img"]),
        Orientationd(keys=["img"], axcodes="RAS"),
        Spacingd(keys=["img"], pixdim=(0.9375, 0.9375, 4.0), mode='bilinear'),
        Resized(keys=["img"], spatial_size=(256, 256, 32)),
        NormalizeIntensityd(keys=["img"], nonzero=True, channel_wise=True),
        EnsureTyped(keys=["img"], data_type='tensor', track_meta=False)
    ])
    data = {"img": image_path}
    img_transformed = transform(data)["img"]  # (C, 256, 256, 32)
    img_transformed = img_transformed[1]  # keep the second channel -> (256, 256, 32)
    # Min-max normalize to [0, 1].
    img_min = img_transformed.min()
    img_max = img_transformed.max()
    img_normalized = (img_transformed - img_min) / (img_max - img_min)
    img_numpy = img_normalized.numpy().astype(np.float32)
    np.save(npy_path, img_numpy)
    return npy_path
train_files = [{"img": preprocess_image(img_path), "text": text}
               for img_path, text in zip(internal_image_paths, internal_summaries)]
internal_val_files = [{"img": preprocess_image(img_path), "text": text}
                      for img_path, text in zip(internal_val_image_paths, internal_val_summaries)]
print(np.shape(np.load(train_files[0]["img"])))
print("Train files:", len(train_files))
print("Internal validation files:", len(internal_val_files))
tokenizer = AutoTokenizer.from_pretrained("GoodBaiBai88/M3D-CLIP", model_max_length=512, padding_side="right", use_fast=False)
model = AutoModel.from_pretrained("GoodBaiBai88/M3D-CLIP", trust_remote_code=True)
model = model.to(device=device)
model.gather_loss = False
class CustomDataset(Dataset):
    def __init__(self, data_files):
        self.data_files = data_files

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        # Cached .npy volume is (256, 256, 32); add a channel dim -> (1, 256, 256, 32).
        # (The second MR channel was already selected in preprocess_image.)
        image = torch.from_numpy(np.load(self.data_files[idx]["img"]).astype(np.float32))[None, ...]
        text_tensor = tokenizer(self.data_files[idx]["text"], max_length=512, truncation=True, padding="max_length", return_tensors="pt")
        input_id = text_tensor["input_ids"].squeeze(0)
        attention_mask = text_tensor["attention_mask"].squeeze(0)
        return image, input_id, attention_mask
train_dataset = CustomDataset(train_files)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
EPOCHS = 100
best_loss = float('inf')
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}", leave=False):
        optimizer.zero_grad()
        images, input_ids, attention_masks = batch
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        # Contrastive targets: the i-th image matches the i-th report in the batch.
        labels = torch.arange(len(images), dtype=torch.long, device=device)
        ret = model(images, input_ids, attention_masks, labels)
        total_loss = ret["loss"]
        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
    if epoch % 10 == 0:
        print(f"epoch_loss: {epoch_loss} best_loss: {best_loss}")
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        if os.path.exists(best_model_path):
            os.remove(best_model_path)
        torch.save(model.state_dict(), best_model_path)
print("Training finished.")
Hi,
If you use another dataset, especially MRI data rather than CT, it is best to fine-tune this M3D-CLIP model on your dataset to reduce domain bias.
In data preprocessing, you should make the shape 32x256x256 (DxHxW) and normalize the intensities to the range 0-1.
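For example, a minimal sketch of this preprocessing, assuming your cached volume is (H, W, D) = (256, 256, 32) as in your script (to_m3dclip_input is just an illustrative name):

import numpy as np

def to_m3dclip_input(volume_hwd):
    # Reorder (H, W, D) -> (D, H, W) = (32, 256, 256), as recommended above.
    vol = np.transpose(volume_hwd, (2, 0, 1)).astype(np.float32)
    # Min-max normalize to [0, 1].
    vol = (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)
    return vol[None, ...]  # add channel dim -> (1, 32, 256, 256)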
Best regards,
BAI Fan