VikramSingh178 commited on
Commit
e27eecd
1 Parent(s): 441421d
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+
4
+
5
+ # Add a title
6
+ st.set_page_config(page_title="Select Diagnosis", layout="centered")
7
+
8
+
9
+
10
+ st.title("Medical Diagnosis App")
11
+
12
+ st.markdown("")
13
+ st.markdown("<li> Currently Brain Tumors , Xrays and Skin Leison Analysis are ready for diagnosis </li>"
14
+ "<li>The Models also explain what area in the images is the cause of diagnosis </li>"
15
+ "<li>Currently the models are trained on a small dataset and will be trained on a larger dataset in the future</li>"
16
+ '<li> The Application also provides generated information on how to diagnose the disease and what should the patient do in that case</li>'
17
+ ,unsafe_allow_html=True)
18
+
19
+ with st.sidebar.container():
20
+ image = Image.open("/Users/vikram/Downloads/Meditechlogo.png")
21
+ st.image(image, caption='Meditech',use_column_width=True)
22
+
23
+
24
+
25
+
brain_labels.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "0":"Glinomia",
3
+ "1": "Meningomia",
4
+ "2":"notumar",
5
+ "3": "pituary"
6
+ }
labels.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "0":"Covid19",
3
+ "1": "Normal",
4
+ "2":"Pneumonia",
5
+ "3": "Tuberculosis"
6
+ }
models/brain_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52a7a2393337e2dedca977e4729df96b2f55579c58cbd2263922d0f8752fd866
3
+ size 79728020
models/eye_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e83d4f41fd93bd35208cb9a1643158ee33b781ea111105fe614e2738edb3fe7
3
+ size 27238762
models/timm_skin_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5163d00507a13ccd735162fd43dd35b16cbb901bc06d407a4deb1cef194a0e2
3
+ size 16408803
models/timm_xray_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe8a9cda4e25a731216bb6503f36757e9ebde7251f2c1f34c1aa3489707419c4
3
+ size 93909024
pages/Brain.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch.nn as nn
4
+ import timm
5
+ import torch
6
+ import torchmetrics
7
+ from torchmetrics import F1Score,Recall,Accuracy
8
+ import torch.optim.lr_scheduler as lr_scheduler
9
+ import torchvision.models as models
10
+ import lightning.pytorch as pl
11
+ import torchvision
12
+ from lightning.pytorch.loggers import WandbLogger
13
+ import shap
14
+ import matplotlib.pyplot as plt
15
+ import json
16
+ from transformers import pipeline, set_seed
17
+ from transformers import BioGptTokenizer, BioGptForCausalLM
18
+ text_model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
19
+ tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
20
+ labels_path = '/Users/vikram/Python/Medical Diagnosis App/brain_labels.json'
21
+ from captum.attr import DeepLift , visualization
22
+
23
+ with open(labels_path) as json_data:
24
+ idx_to_labels = json.load(json_data)
25
+
26
+
27
+
28
+ class FineTuneModel(pl.LightningModule):
29
+ def __init__(self, model_name, num_classes, learning_rate, dropout_rate,beta1,beta2,eps):
30
+ super().__init__()
31
+ self.model_name = model_name
32
+ self.num_classes = num_classes
33
+ self.learning_rate = learning_rate
34
+ self.beta1 = beta1
35
+ self.beta2 = beta2
36
+ self.eps = eps
37
+ self.dropout_rate = dropout_rate
38
+ self.model = timm.create_model(self.model_name, pretrained=True,num_classes=self.num_classes)
39
+ self.loss_fn = nn.CrossEntropyLoss()
40
+ self.f1 = F1Score(task='multiclass', num_classes=self.num_classes)
41
+ self.recall = Recall(task='multiclass', num_classes=self.num_classes)
42
+ self.accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
43
+
44
+ #for param in self.model.parameters():
45
+ #param.requires_grad = True
46
+ #self.model.classifier= nn.Sequential(nn.Dropout(p=self.dropout_rate),nn.Linear(self.model.classifier.in_features, self.num_classes))
47
+ #self.model.classifier.requires_grad = True
48
+
49
+
50
+ def forward(self, x):
51
+ return self.model(x)
52
+
53
+ def training_step(self, batch, batch_idx):
54
+ x, y = batch
55
+ y_hat = self.model(x)
56
+ loss = self.loss_fn(y_hat, y)
57
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
58
+ f1 = self.f1(y_hat.argmax(dim=1),y)
59
+ recall = self.recall(y_hat.argmax(dim=1),y)
60
+ self.log('train_loss', loss,on_step=False,on_epoch=True)
61
+ self.log('train_acc', acc,on_step=False,on_epoch = True)
62
+ self.log('train_f1',f1,on_step=False,on_epoch=True)
63
+ self.log('train_recall',recall,on_step=False,on_epoch=True)
64
+ return loss
65
+
66
+ def validation_step(self, batch, batch_idx):
67
+ x, y = batch
68
+ y_hat = self.model(x)
69
+ loss = self.loss_fn(y_hat, y)
70
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
71
+ f1 = self.f1(y_hat.argmax(dim=1),y)
72
+ recall = self.recall(y_hat.argmax(dim=1),y)
73
+ self.log('val_loss', loss,on_step=False,on_epoch=True)
74
+ self.log('val_acc', acc,on_step=False,on_epoch=True)
75
+ self.log('val_f1',f1,on_step=False,on_epoch=True)
76
+ self.log('val_recall',recall,on_step=False,on_epoch=True)
77
+
78
+
79
+ def configure_optimizers(self):
80
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate,betas=(self.beta1,self.beta2),eps=self.eps)
81
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
82
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
83
+
84
+
85
+ #load model
86
+
87
+
88
+
89
+
90
+
91
+ st.markdown("<h1 style='text-align: center; '>Brain Tumor Diagnosis</h1>",unsafe_allow_html=True)
92
+
93
+
94
+
95
+
96
+ # Display a file uploader widget for the user to upload an image
97
+
98
+ uploaded_file = st.file_uploader("Choose an Brain MRI image file", type=["jpg", "jpeg", "png"])
99
+
100
+ # Load the uploaded image, or display emojis if no file was uploaded
101
+ with st.container():
102
+ if uploaded_file is not None:
103
+
104
+ image = Image.open(uploaded_file)
105
+ st.image(image, caption='Diagnosis', use_column_width=True)
106
+ model = timm.create_model(model_name='efficientnet_b1', pretrained=True,num_classes=4)
107
+ data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
108
+ transform = timm.data.create_transform(**data_cfg)
109
+ model_transforms = torchvision.transforms.Compose([transform])
110
+ transformed_image = model_transforms(image)
111
+ brain_model = torch.load('models/brain_model.pth')
112
+
113
+ brain_model.eval()
114
+ with torch.inference_mode():
115
+ with st.progress(100):
116
+
117
+ #class_names = ['Glinomia','Meningomia','notumar','pituary']
118
+ prediction = torch.nn.functional.softmax(brain_model(transformed_image.unsqueeze(dim=0))[0], dim=0)
119
+ prediction_score, pred_label_idx = torch.topk(prediction, 1)
120
+ pred_label_idx.squeeze_()
121
+ predicted_label = idx_to_labels[str(pred_label_idx.item())]
122
+ st.write( f'Predicted Label: {predicted_label}')
123
+ if st.button('Know More'):
124
+ generator = pipeline("text-generation",model=text_model,tokenizer=tokenizer)
125
+ input_text = f"Patient has {predicted_label} and is advised to take the following medicines:"
126
+ with st.spinner('Generating Text'):
127
+ generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
128
+ st.markdown(generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text'])
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+ else:
143
+ st.success("Please upload an image file 🧠")
144
+
145
+
146
+
147
+
148
+
149
+ ## Model Explainibilty Dashboard using Captum
150
+
151
+
152
+
153
+
154
+
155
+
156
+
pages/Chest.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch.nn as nn
4
+ import timm
5
+ import torch
6
+ import time
7
+ import torchmetrics
8
+ from torchmetrics import F1Score,Recall,Accuracy
9
+ import torch.optim.lr_scheduler as lr_scheduler
10
+ import torchvision.models as models
11
+ import lightning.pytorch as pl
12
+ import torchvision
13
+ from lightning.pytorch.loggers import WandbLogger
14
+ import captum
15
+ import matplotlib.pyplot as plt
16
+ import json
17
+ from transformers import pipeline, set_seed
18
+ from transformers import BioGptTokenizer, BioGptForCausalLM
19
+ text_model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
20
+ tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
21
+ labels_path = '/Users/vikram/Python/Medical Diagnosis App/labels.json'
22
+
23
+
24
+ with open(labels_path) as json_data:
25
+ idx_to_labels = json.load(json_data)
26
+
27
+
28
+
29
+ class FineTuneModel(pl.LightningModule):
30
+ def __init__(self, model_name, num_classes, learning_rate, dropout_rate,beta1,beta2,eps):
31
+ super().__init__()
32
+ self.model_name = model_name
33
+ self.num_classes = num_classes
34
+ self.learning_rate = learning_rate
35
+ self.beta1 = beta1
36
+ self.beta2 = beta2
37
+ self.eps = eps
38
+ self.dropout_rate = dropout_rate
39
+ self.model = timm.create_model(self.model_name, pretrained=True,num_classes=self.num_classes)
40
+ self.loss_fn = nn.CrossEntropyLoss()
41
+ self.f1 = F1Score(task='multiclass', num_classes=self.num_classes)
42
+ self.recall = Recall(task='multiclass', num_classes=self.num_classes)
43
+ self.accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
44
+
45
+ #for param in self.model.parameters():
46
+ #param.requires_grad = True
47
+ #self.model.classifier= nn.Sequential(nn.Dropout(p=self.dropout_rate),nn.Linear(self.model.classifier.in_features, self.num_classes))
48
+ #self.model.classifier.requires_grad = True
49
+
50
+
51
+ def forward(self, x):
52
+ return self.model(x)
53
+
54
+ def training_step(self, batch, batch_idx):
55
+ x, y = batch
56
+ y_hat = self.model(x)
57
+ loss = self.loss_fn(y_hat, y)
58
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
59
+ f1 = self.f1(y_hat.argmax(dim=1),y)
60
+ recall = self.recall(y_hat.argmax(dim=1),y)
61
+ self.log('train_loss', loss,on_step=False,on_epoch=True)
62
+ self.log('train_acc', acc,on_step=False,on_epoch = True)
63
+ self.log('train_f1',f1,on_step=False,on_epoch=True)
64
+ self.log('train_recall',recall,on_step=False,on_epoch=True)
65
+ return loss
66
+
67
+ def validation_step(self, batch, batch_idx):
68
+ x, y = batch
69
+ y_hat = self.model(x)
70
+ loss = self.loss_fn(y_hat, y)
71
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
72
+ f1 = self.f1(y_hat.argmax(dim=1),y)
73
+ recall = self.recall(y_hat.argmax(dim=1),y)
74
+ self.log('val_loss', loss,on_step=False,on_epoch=True)
75
+ self.log('val_acc', acc,on_step=False,on_epoch=True)
76
+ self.log('val_f1',f1,on_step=False,on_epoch=True)
77
+ self.log('val_recall',recall,on_step=False,on_epoch=True)
78
+
79
+
80
+ def configure_optimizers(self):
81
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate,betas=(self.beta1,self.beta2),eps=self.eps)
82
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
83
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
84
+
85
+
86
+ #load model
87
+
88
+
89
+
90
+
91
+
92
+ st.markdown("<h1 style='text-align: center; '>Chest Xray Diagnosis</h1>",unsafe_allow_html=True)
93
+
94
+
95
+
96
+
97
+ # Display a file uploader widget for the user to upload an image
98
+ uploaded_file = st.file_uploader("Choose an Chest XRay Image file", type=["jpg", "jpeg", "png"])
99
+
100
+ # Load the uploaded image, or display emojis if no file was uploaded
101
+ if uploaded_file is not None:
102
+
103
+ image = Image.open(uploaded_file)
104
+ st.image(image, caption='Diagnosis',width=224, use_column_width=True)
105
+ model = timm.create_model(model_name='efficientnet_b2', pretrained=True,num_classes=4)
106
+ data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
107
+ transform = timm.data.create_transform(**data_cfg)
108
+ model_transforms = torchvision.transforms.Compose([transform])
109
+ transformed_image = model_transforms(image)
110
+ xray_model = torch.load('models/timm_xray_model.pth')
111
+
112
+ xray_model.eval()
113
+
114
+
115
+
116
+ with torch.inference_mode():
117
+ with st.progress(100):
118
+
119
+ prediction = torch.nn.functional.softmax(xray_model(transformed_image.unsqueeze(dim=0))[0], dim=0)
120
+ prediction_score, pred_label_idx = torch.topk(prediction, 1)
121
+ pred_label_idx.squeeze_()
122
+ predicted_label = idx_to_labels[str(pred_label_idx.item())]
123
+ st.write( f'Predicted Label: {predicted_label}')
124
+ if st.button('Know More'):
125
+ generator = pipeline("text-generation",model=text_model,tokenizer=tokenizer)
126
+ input_text = f"Patient has {predicted_label} and is advised to take the following medicines:"
127
+ with st.spinner('Generating Text'):
128
+ generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
129
+ st.markdown(generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text'])
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+ else:
142
+ st.success("Please upload an image file ⚕️")
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
pages/Model Dashboard.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.components.v1.iframe(src = 'https://api.wandb.ai/links/vikramxd/nw5ru81j',width = 1000, height = 800,scrolling = True)
4
+
5
+
6
+
7
+
pages/Skin.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch.nn as nn
4
+ import timm
5
+ import torch
6
+ import torchmetrics
7
+ from torchmetrics import F1Score,Recall,Accuracy
8
+ import torch.optim.lr_scheduler as lr_scheduler
9
+ import torchvision.models as models
10
+ import lightning.pytorch as pl
11
+ import torchvision
12
+ from lightning.pytorch.loggers import WandbLogger
13
+ import shap
14
+ import matplotlib.pyplot as plt
15
+ import json
16
+ from transformers import pipeline, set_seed
17
+ from transformers import BioGptTokenizer, BioGptForCausalLM
18
+ text_model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
19
+ tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
20
+ labels_path = '/Users/vikram/Python/Medical Diagnosis App/skin_labels.json'
21
+ from captum.attr import DeepLift , visualization
22
+
23
+ with open(labels_path) as json_data:
24
+ idx_to_labels = json.load(json_data)
25
+
26
+
27
+
28
+ class FineTuneModel(pl.LightningModule):
29
+ def __init__(self, model_name, num_classes, learning_rate, dropout_rate,beta1,beta2,eps):
30
+ super().__init__()
31
+ self.model_name = model_name
32
+ self.num_classes = num_classes
33
+ self.learning_rate = learning_rate
34
+ self.beta1 = beta1
35
+ self.beta2 = beta2
36
+ self.eps = eps
37
+ self.dropout_rate = dropout_rate
38
+ self.model = timm.create_model(self.model_name, pretrained=True,num_classes=self.num_classes)
39
+ self.loss_fn = nn.CrossEntropyLoss()
40
+ self.f1 = F1Score(task='multiclass', num_classes=self.num_classes)
41
+ self.recall = Recall(task='multiclass', num_classes=self.num_classes)
42
+ self.accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
43
+
44
+ #for param in self.model.parameters():
45
+ #param.requires_grad = True
46
+ #self.model.classifier= nn.Sequential(nn.Dropout(p=self.dropout_rate),nn.Linear(self.model.classifier.in_features, self.num_classes))
47
+ #self.model.classifier.requires_grad = True
48
+
49
+
50
+ def forward(self, x):
51
+ return self.model(x)
52
+
53
+ def training_step(self, batch, batch_idx):
54
+ x, y = batch
55
+ y_hat = self.model(x)
56
+ loss = self.loss_fn(y_hat, y)
57
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
58
+ f1 = self.f1(y_hat.argmax(dim=1),y)
59
+ recall = self.recall(y_hat.argmax(dim=1),y)
60
+ self.log('train_loss', loss,on_step=False,on_epoch=True)
61
+ self.log('train_acc', acc,on_step=False,on_epoch = True)
62
+ self.log('train_f1',f1,on_step=False,on_epoch=True)
63
+ self.log('train_recall',recall,on_step=False,on_epoch=True)
64
+ return loss
65
+
66
+ def validation_step(self, batch, batch_idx):
67
+ x, y = batch
68
+ y_hat = self.model(x)
69
+ loss = self.loss_fn(y_hat, y)
70
+ acc = self.accuracy(y_hat.argmax(dim=1),y)
71
+ f1 = self.f1(y_hat.argmax(dim=1),y)
72
+ recall = self.recall(y_hat.argmax(dim=1),y)
73
+ self.log('val_loss', loss,on_step=False,on_epoch=True)
74
+ self.log('val_acc', acc,on_step=False,on_epoch=True)
75
+ self.log('val_f1',f1,on_step=False,on_epoch=True)
76
+ self.log('val_recall',recall,on_step=False,on_epoch=True)
77
+
78
+
79
+ def configure_optimizers(self):
80
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate,betas=(self.beta1,self.beta2),eps=self.eps)
81
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
82
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
83
+
84
+
85
+ #load model
86
+
87
+
88
+
89
+
90
+
91
+ st.markdown("<h1 style='text-align: center; '>Skin Leision Diagnosis</h1>",unsafe_allow_html=True)
92
+
93
+
94
+
95
+
96
+ # Display a file uploader widget for the user to upload an image
97
+
98
+ uploaded_file = st.file_uploader("Choose an Skin image file", type=["jpg", "jpeg", "png"])
99
+
100
+ # Load the uploaded image, or display emojis if no file was uploaded
101
+ with st.container():
102
+ if uploaded_file is not None:
103
+
104
+ image = Image.open(uploaded_file)
105
+ st.image(image, caption='Diagnosis', use_column_width=True)
106
+ model = timm.create_model(model_name='efficientnet_b0', pretrained=True,num_classes=4)
107
+ data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
108
+ transform = timm.data.create_transform(**data_cfg)
109
+ model_transforms = torchvision.transforms.Compose([transform])
110
+ transformed_image = model_transforms(image)
111
+ brain_model = torch.load('models/timm_skin_model.pth')
112
+
113
+ brain_model.eval()
114
+ with torch.inference_mode():
115
+ with st.progress(100):
116
+
117
+ #class_names = ['Glinomia','Meningomia','notumar','pituary']
118
+ prediction = torch.nn.functional.softmax(brain_model(transformed_image.unsqueeze(dim=0))[0], dim=0)
119
+ prediction_score, pred_label_idx = torch.topk(prediction, 1)
120
+ pred_label_idx.squeeze_()
121
+ predicted_label = idx_to_labels[str(pred_label_idx.item())]
122
+ st.write( f'Predicted Label: {predicted_label}')
123
+ if st.button('Know More'):
124
+ generator = pipeline("text-generation",model=text_model,tokenizer=tokenizer)
125
+ input_text = f"Patient has {predicted_label} and is advised to take the following medicines:"
126
+ with st.spinner('Generating Text'):
127
+ generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
128
+ st.markdown(generator(input_text, max_length=300, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)[0]['generated_text'])
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+ else:
143
+ st.success("Please upload an image file 🧠")
144
+
145
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ captum==0.6.0
2
+ lightning==2.0.1
3
+ matplotlib==3.6.3
4
+ Pillow==9.5.0
5
+ shap==0.41.0
6
+ streamlit==1.20.0
7
+ timm==0.6.13
8
+ torch==2.0.0
9
+ torchmetrics==0.11.4
10
+ torchvision==0.15.1
11
+ transformers==4.27.4
skin_labels.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0":"actinic keratoses and intraepithelial carcinoma",
3
+ "1": "basal cell carcinoma",
4
+ "2":"benign keratosis-like lesions",
5
+ "3": "dermatofibroma",
6
+ "4": "melanoma",
7
+ "5": "melanocytic nevi",
8
+ "6": "vascular lesions"
9
+
10
+ }