import gradio as gr
import timm
import torch
import torch.nn as nn


def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-channel convolution of ``model`` to ``in_channels`` inputs.

    Assumes the model's first convolution takes 3 input channels. The kernel of
    the first ``nn.Conv2d``/``nn.Conv3d`` with ``in_channels == 3`` is summed
    over its input-channel axis, divided by ``in_channels`` and tiled
    ``in_channels`` times. An input whose channels are all identical therefore
    gets the same response the original layer gave to a grey image replicated
    across R, G and B. Only that one layer is changed. The model is modified in
    place and also returned for convenience.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Conv3d)) and module.in_channels == 3:
            module.in_channels = in_channels
            # Collapse the RGB kernels into one, then spread it evenly over the
            # new number of input channels.
            weight = module.weight.sum(1, keepdim=True) / in_channels
            repeats = [1] * weight.ndim
            repeats[1] = in_channels
            module.weight = nn.Parameter(weight.repeat(repeats))
            break
    return model


class Net2D(nn.Module):
    """EfficientNetV2-S regressor taking a 2-channel input (image + sex channel).

    Args:
        weights: state dict loaded into the freshly built network; its keys
            must match this module's parameter names.
    """

    def __init__(self, weights):
        super().__init__()
        self.backbone = timm.create_model(
            "tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0
        )
        self.backbone = change_num_input_channels(self.backbone, 2)
        self.pool_layer = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        # 1280 is assumed to be the backbone's output feature width;
        # load_state_dict below fails loudly if it does not match the checkpoint.
        self.classifier = nn.Linear(1280, 1)
        self.load_state_dict(weights)

    def forward(self, x):
        x = self.backbone(x)
        x = self.pool_layer(x).view(x.size(0), -1)
        x = self.dropout(x)
        x = self.classifier(x)
        # Drop the singleton output axis so callers get one value per sample.
        return x[:, 0] if x.size(1) == 1 else x


class Ensemble(nn.Module):
    """Average the outputs of several models evaluated on the same input."""

    def __init__(self, model_list):
        super().__init__()
        self.model_list = nn.ModuleList(model_list)

    def forward(self, x):
        return torch.stack([model(x) for model in self.model_list]).mean(0)


def _strip_prefix(key, prefix="model."):
    """Remove a leading ``prefix`` from a state-dict key, if present."""
    return key[len(prefix):] if key.startswith(prefix) else key


checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
weights = [
    torch.load(ckpt, map_location=torch.device("cpu"))["state_dict"]
    for ckpt in checkpoints
]
# Checkpoint keys carry a leading "model." prefix (presumably from a wrapper
# module used during training). Strip only that leading prefix so the keys
# match Net2D's own parameter names.
weights = [{_strip_prefix(k): v for k, v in wt.items()} for wt in weights]
models = [Net2D(wt) for wt in weights]
ensemble = Ensemble(models).eval()


def _preprocess(radiograph, sex):
    """Build the (1, 2, H, W) float tensor the ensemble is fed.

    Channel 0 is the radiograph divided by its own maximum and mapped via
    ``(x - 0.5) * 2``; for non-negative pixel values this gives [-1, 1].
    Channel 1 is a constant +1 when ``sex == 1`` and -1 otherwise.
    """
    img = torch.from_numpy(radiograph).float().unsqueeze(0).unsqueeze(0)
    peak = img.max()
    # A blank (all-zero) image would otherwise be 0 / 0 = NaN everywhere.
    if peak > 0:
        img = img / peak
    img = (img - 0.5) * 2.0
    # With the Radio choices ["Male", "Female"] and type="index", 1 is Female.
    # NOTE(review): +1 for Female / -1 for Male is presumably the training-time
    # encoding; confirm against the training code.
    sex_value = 1.0 if sex == 1 else -1.0
    return torch.cat([img, torch.full_like(img, sex_value)], dim=1)


def _plural(count, unit):
    """Return "<count> <unit>", appending "s" unless count is exactly 1."""
    return f"{count} {unit}" if count == 1 else f"{count} {unit}s"


def _format_age(age_years):
    """Render an age in fractional years as text, e.g. "7 years, 3 months".

    The age is rounded to the nearest whole month with Python's ``round``.
    Negative values (which a regression head can emit) are clamped to 0 months.
    """
    total_months = max(0, round(age_years * 12))
    years, months = divmod(total_months, 12)
    if years == 0:
        return _plural(months, "month")
    if months == 0:
        return _plural(years, "year")
    return f"{_plural(years, 'year')}, {_plural(months, 'month')}"


def predict_bone_age(Radiograph, Sex):
    """Estimate bone age from a greyscale radiograph with the fold ensemble.

    Args:
        Radiograph: numpy image from the ``gr.Image(image_mode="L")`` input, or
            ``None`` when nothing was uploaded.
        Sex: index from the ``["Male", "Female"]`` radio (0 = Male,
            1 = Female), or ``None`` when nothing was selected.

    Returns:
        "Estimated Bone Age: <age>" where the ensemble output is interpreted
        as years, or a prompt naming the missing input.
    """
    if Radiograph is None:
        return "Please upload a radiograph."
    if Sex is None:
        # Previously an unselected radio silently fell through to the Male
        # encoding; ask instead of guessing.
        return "Please select the patient's sex."
    img = _preprocess(Radiograph, Sex)
    with torch.no_grad():
        bone_age = ensemble(img)[0].item()
    return f"Estimated Bone Age: {_format_age(bone_age)}"


# `shape=` is a Gradio 3.x gr.Image argument (removed in Gradio 4).
image = gr.Image(shape=(512, 512), image_mode="L")
sex = gr.Radio(["Male", "Female"], type="index")
label = gr.Label(show_label=True, label="Result")

demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image, sex],
    outputs=label,
)

if __name__ == "__main__":
    demo.launch()