Spaces:
Runtime error
Runtime error
chris1nexus
commited on
Commit
•
54660f7
1
Parent(s):
6c29b16
First commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +4 -4
- app.py +175 -0
- feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb +0 -0
- feature_extractor/__init__.py +0 -0
- feature_extractor/__pycache__/__init__.cpython-38.pyc +0 -0
- feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc +0 -0
- feature_extractor/__pycache__/build_graphs.cpython-38.pyc +0 -0
- feature_extractor/__pycache__/cl.cpython-38.pyc +0 -0
- feature_extractor/__pycache__/simclr.cpython-36.pyc +0 -0
- feature_extractor/__pycache__/simclr.cpython-38.pyc +0 -0
- feature_extractor/build_graph_utils.py +85 -0
- feature_extractor/build_graphs.py +114 -0
- feature_extractor/cl.py +83 -0
- feature_extractor/config.yaml +23 -0
- feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc +0 -0
- feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc +0 -0
- feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc +0 -0
- feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc +0 -0
- feature_extractor/data_aug/dataset_wrapper.py +93 -0
- feature_extractor/data_aug/gaussian_blur.py +26 -0
- feature_extractor/load_patches.py +37 -0
- feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc +0 -0
- feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc +0 -0
- feature_extractor/loss/nt_xent.py +65 -0
- feature_extractor/models/__init__.py +0 -0
- feature_extractor/models/__pycache__/__init__.cpython-38.pyc +0 -0
- feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc +0 -0
- feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc +0 -0
- feature_extractor/models/baseline_encoder.py +43 -0
- feature_extractor/models/resnet_simclr.py +37 -0
- feature_extractor/run.py +21 -0
- feature_extractor/simclr.py +165 -0
- feature_extractor/viewer.py +227 -0
- helper.py +104 -0
- main.py +169 -0
- metadata/label_map.pkl +3 -0
- models/.gitkeep +1 -0
- models/GraphTransformer.py +123 -0
- models/ViT.py +415 -0
- models/__init__.py +0 -0
- models/__pycache__/GraphTransformer.cpython-38.pyc +0 -0
- models/__pycache__/ViT.cpython-38.pyc +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/gcn.cpython-38.pyc +0 -0
- models/__pycache__/layers.cpython-38.pyc +0 -0
- models/__pycache__/weight_init.cpython-38.pyc +0 -0
- models/gcn.py +420 -0
- models/layers.py +280 -0
- models/weight_init.py +78 -0
- option.py +41 -0
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: AioMedica
|
3 |
+
emoji: 🏃
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: yellow
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import openslide
|
3 |
+
import os
|
4 |
+
from streamlit_option_menu import option_menu
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache(suppress_st_warning=True)
|
9 |
+
def load_model():
|
10 |
+
from predict import Predictor
|
11 |
+
predictor = Predictor()
|
12 |
+
return predictor
|
13 |
+
|
14 |
+
@st.cache(suppress_st_warning=True)
|
15 |
+
def load_dependencies():
|
16 |
+
|
17 |
+
os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
18 |
+
os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
19 |
+
os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html")
|
20 |
+
|
21 |
+
|
22 |
+
def main():
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
# environment variables for the inference api
|
29 |
+
os.environ['DATA_DIR'] = 'queries'
|
30 |
+
os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
|
31 |
+
os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
|
32 |
+
os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
|
33 |
+
os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
|
34 |
+
|
35 |
+
|
36 |
+
# manually put the metadata in the metadata folder
|
37 |
+
os.environ['CLASS_METADATA'] ='metadata/label_map.pkl'
|
38 |
+
|
39 |
+
# manually put the desired weights in the weights folder
|
40 |
+
os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights'
|
41 |
+
os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
|
42 |
+
os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')
|
43 |
+
|
44 |
+
|
45 |
+
#st.set_page_config(page_title="",layout='wide')
|
46 |
+
predictor = load_model()#Predictor()
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
|
53 |
+
CONTACT_TEXT = """
|
54 |
+
_Built by Christian Cancedda and LabLab lads with love_ ❤️
|
55 |
+
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
|
56 |
+
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
|
57 |
+
Star project repository:
|
58 |
+
[![GitHub stars](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
|
59 |
+
"""
|
60 |
+
VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
|
61 |
+
DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
with st.sidebar:
|
66 |
+
choice = option_menu("LastMinute - Diagnosis",
|
67 |
+
["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
|
68 |
+
icons=['house', 'upload', 'activity', 'person lines fill'],
|
69 |
+
menu_icon="app-indicator", default_index=0,
|
70 |
+
styles={
|
71 |
+
# "container": {"padding": "5!important", "background-color": "#fafafa", },
|
72 |
+
"container": {"border-radius": ".0rem"},
|
73 |
+
# "icon": {"color": "orange", "font-size": "25px"},
|
74 |
+
# "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px",
|
75 |
+
# "--hover-color": "#eee"},
|
76 |
+
# "nav-link-selected": {"background-color": "#02ab21"},
|
77 |
+
}
|
78 |
+
)
|
79 |
+
st.sidebar.markdown(
|
80 |
+
"""
|
81 |
+
<style>
|
82 |
+
.aligncenter {
|
83 |
+
text-align: center;
|
84 |
+
}
|
85 |
+
</style>
|
86 |
+
<p style='text-align: center'>
|
87 |
+
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">Project Repository</a>
|
88 |
+
</p>
|
89 |
+
<p class="aligncenter">
|
90 |
+
|
91 |
+
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">
|
92 |
+
<img src="https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social"/>
|
93 |
+
</a>
|
94 |
+
</p>
|
95 |
+
|
96 |
+
<p class="aligncenter">
|
97 |
+
<a href="https://twitter.com/chris_cancedda" target="_blank">
|
98 |
+
<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
|
99 |
+
</a>
|
100 |
+
</p>
|
101 |
+
""",
|
102 |
+
unsafe_allow_html=True,
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
if choice == "About":
|
107 |
+
st.title(choice)
|
108 |
+
README = requests.get("https://raw.githubusercontent.com/Chris1nexus/inference-graph-transformer/master/README.md").text
|
109 |
+
README = str(README).replace('width="1200"','width="700"')
|
110 |
+
# st.title(choose)
|
111 |
+
st.markdown(README, unsafe_allow_html=True)
|
112 |
+
|
113 |
+
if choice == "Visualize WSI slide":
|
114 |
+
st.title(choice)
|
115 |
+
st.markdown(VISUALIZE_TEXT)
|
116 |
+
|
117 |
+
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
118 |
+
if uploaded_file is not None:
|
119 |
+
ori = openslide.OpenSlide(uploaded_file.name)
|
120 |
+
width, height = ori.dimensions
|
121 |
+
|
122 |
+
REDUCTION_FACTOR = 20
|
123 |
+
w, h = int(width/512), int(height/512)
|
124 |
+
w_r, h_r = int(width/20), int(height/20)
|
125 |
+
resized_img = ori.get_thumbnail((w_r,h_r))
|
126 |
+
resized_img = resized_img.resize((w_r,h_r))
|
127 |
+
ratio_w, ratio_h = width/resized_img.width, height/resized_img.height
|
128 |
+
#print('ratios ', ratio_w, ratio_h)
|
129 |
+
w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR)
|
130 |
+
st.image(resized_img, use_column_width='never')
|
131 |
+
|
132 |
+
if choice == "Cancer Detection":
|
133 |
+
state = dict()
|
134 |
+
|
135 |
+
st.title(choice)
|
136 |
+
st.markdown(DETECT_TEXT)
|
137 |
+
uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
|
138 |
+
st.markdown("Examples can be chosen at the [GDC Data repository](https://portal.gdc.cancer.gov/repository?facetTab=cases&filters=%7B%22op%22%3A%22and%22%2C%22content%22%3A%5B%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.primary_site%22%2C%22value%22%3A%5B%22bronchus%20and%20lung%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.program.name%22%2C%22value%22%3A%5B%22TCGA%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22cases.project.project_id%22%2C%22value%22%3A%5B%22TCGA-LUAD%22%2C%22TCGA-LUSC%22%5D%7D%7D%2C%7B%22op%22%3A%22in%22%2C%22content%22%3A%7B%22field%22%3A%22files.experimental_strategy%22%2C%22value%22%3A%5B%22Tissue%20Slide%22%5D%7D%7D%5D%7D)")
|
139 |
+
st.markdown("Alternatively, for simplicity few test cases are provided at the [drive link](https://drive.google.com/drive/folders/1u3SQa2dytZBHHh6eXTlMKY-pZGZ-pwkk?usp=share_link)")
|
140 |
+
|
141 |
+
|
142 |
+
if uploaded_file is not None:
|
143 |
+
# To read file as bytes:
|
144 |
+
#print(uploaded_file)
|
145 |
+
with open(os.path.join(uploaded_file.name),"wb") as f:
|
146 |
+
f.write(uploaded_file.getbuffer())
|
147 |
+
with st.spinner(text="Computation is running"):
|
148 |
+
predicted_class, viz_dict = predictor.predict(uploaded_file.name)
|
149 |
+
st.info('Computation completed.')
|
150 |
+
st.header(f'Predicted to be: {predicted_class}')
|
151 |
+
st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
|
152 |
+
state['cur'] = predicted_class
|
153 |
+
mapper = {'ORI': predicted_class, predicted_class:'ORI'}
|
154 |
+
readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' }
|
155 |
+
#def fn():
|
156 |
+
# st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR')
|
157 |
+
# state['cur'] = mapper[state['cur']]
|
158 |
+
# return
|
159 |
+
|
160 |
+
#st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn )
|
161 |
+
#st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR')
|
162 |
+
st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR'
|
163 |
+
# use_column_width='never',
|
164 |
+
)
|
165 |
+
|
166 |
+
|
167 |
+
if choice == "Contact":
|
168 |
+
st.title(choice)
|
169 |
+
st.markdown(CONTACT_TEXT)
|
170 |
+
|
171 |
+
if __name__ == '__main__':
|
172 |
+
#'''
|
173 |
+
load_dependencies()
|
174 |
+
#'''
|
175 |
+
main()
|
feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
feature_extractor/__init__.py
ADDED
File without changes
|
feature_extractor/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (179 Bytes). View file
|
|
feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc
ADDED
Binary file (3.48 kB). View file
|
|
feature_extractor/__pycache__/build_graphs.cpython-38.pyc
ADDED
Binary file (6.45 kB). View file
|
|
feature_extractor/__pycache__/cl.cpython-38.pyc
ADDED
Binary file (3.05 kB). View file
|
|
feature_extractor/__pycache__/simclr.cpython-36.pyc
ADDED
Binary file (4.38 kB). View file
|
|
feature_extractor/__pycache__/simclr.cpython-38.pyc
ADDED
Binary file (4.5 kB). View file
|
|
feature_extractor/build_graph_utils.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import torchvision.models as models
|
6 |
+
import torchvision.transforms.functional as VF
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
import sys, argparse, os, glob
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
from collections import OrderedDict
|
14 |
+
|
15 |
+
class ToPIL(object):
|
16 |
+
def __call__(self, sample):
|
17 |
+
img = sample
|
18 |
+
img = transforms.functional.to_pil_image(img)
|
19 |
+
return img
|
20 |
+
|
21 |
+
class BagDataset():
|
22 |
+
def __init__(self, csv_file, transform=None):
|
23 |
+
self.files_list = csv_file
|
24 |
+
self.transform = transform
|
25 |
+
def __len__(self):
|
26 |
+
return len(self.files_list)
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
temp_path = self.files_list[idx]
|
29 |
+
img = os.path.join(temp_path)
|
30 |
+
img = Image.open(img)
|
31 |
+
img = img.resize((224, 224))
|
32 |
+
sample = {'input': img}
|
33 |
+
|
34 |
+
if self.transform:
|
35 |
+
sample = self.transform(sample)
|
36 |
+
return sample
|
37 |
+
|
38 |
+
class ToTensor(object):
|
39 |
+
def __call__(self, sample):
|
40 |
+
img = sample['input']
|
41 |
+
img = VF.to_tensor(img)
|
42 |
+
return {'input': img}
|
43 |
+
|
44 |
+
class Compose(object):
|
45 |
+
def __init__(self, transforms):
|
46 |
+
self.transforms = transforms
|
47 |
+
|
48 |
+
def __call__(self, img):
|
49 |
+
for t in self.transforms:
|
50 |
+
img = t(img)
|
51 |
+
return img
|
52 |
+
|
53 |
+
def save_coords(txt_file, csv_file_path):
|
54 |
+
for path in csv_file_path:
|
55 |
+
x, y = path.split('/')[-1].split('.')[0].split('_')
|
56 |
+
txt_file.writelines(str(x) + '\t' + str(y) + '\n')
|
57 |
+
txt_file.close()
|
58 |
+
|
59 |
+
def adj_matrix(csv_file_path, output, device='cpu'):
|
60 |
+
total = len(csv_file_path)
|
61 |
+
adj_s = np.zeros((total, total))
|
62 |
+
|
63 |
+
for i in range(total-1):
|
64 |
+
path_i = csv_file_path[i]
|
65 |
+
x_i, y_i = path_i.split('/')[-1].split('.')[0].split('_')
|
66 |
+
for j in range(i+1, total):
|
67 |
+
# sptial
|
68 |
+
path_j = csv_file_path[j]
|
69 |
+
x_j, y_j = path_j.split('/')[-1].split('.')[0].split('_')
|
70 |
+
if abs(int(x_i)-int(x_j)) <=1 and abs(int(y_i)-int(y_j)) <= 1:
|
71 |
+
adj_s[i][j] = 1
|
72 |
+
adj_s[j][i] = 1
|
73 |
+
|
74 |
+
adj_s = torch.from_numpy(adj_s)
|
75 |
+
adj_s = adj_s.to(device)
|
76 |
+
|
77 |
+
return adj_s
|
78 |
+
|
79 |
+
def bag_dataset(args, csv_file_path):
|
80 |
+
transformed_dataset = BagDataset(csv_file=csv_file_path,
|
81 |
+
transform=Compose([
|
82 |
+
ToTensor()
|
83 |
+
]))
|
84 |
+
dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False)
|
85 |
+
return dataloader, len(transformed_dataset)
|
feature_extractor/build_graphs.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from cl import IClassifier
|
3 |
+
from build_graph_utils import *
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import torchvision.models as models
|
8 |
+
import torchvision.transforms.functional as VF
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
import sys, argparse, os, glob
|
12 |
+
import pandas as pd
|
13 |
+
import numpy as np
|
14 |
+
from PIL import Image
|
15 |
+
from collections import OrderedDict
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None):
|
20 |
+
num_bags = len(bags_list)
|
21 |
+
Tensor = torch.FloatTensor
|
22 |
+
for i in range(0, num_bags):
|
23 |
+
feats_list = []
|
24 |
+
if args.magnification == '20x':
|
25 |
+
glob_path = os.path.join(bags_list[i], '*.jpeg')
|
26 |
+
csv_file_path = glob.glob(glob_path)
|
27 |
+
# line below was in the original version, commented due to errror with current version
|
28 |
+
#file_name = bags_list[i].split('/')[-3].split('_')[0]
|
29 |
+
|
30 |
+
file_name = glob_path.split('/')[-3].split('_')[0]
|
31 |
+
|
32 |
+
if args.magnification == '5x' or args.magnification == '10x':
|
33 |
+
csv_file_path = glob.glob(os.path.join(bags_list[i], '*.jpg'))
|
34 |
+
|
35 |
+
dataloader, bag_size = bag_dataset(args, csv_file_path)
|
36 |
+
print('{} files to be processed: {}'.format(len(csv_file_path), file_name))
|
37 |
+
|
38 |
+
if os.path.isdir(os.path.join(save_path, 'simclr_files', file_name)) or len(csv_file_path) < 1:
|
39 |
+
print('alreday exists')
|
40 |
+
continue
|
41 |
+
with torch.no_grad():
|
42 |
+
for iteration, batch in enumerate(dataloader):
|
43 |
+
patches = batch['input'].float().to(device)
|
44 |
+
feats, classes = i_classifier(patches)
|
45 |
+
#feats = feats.cpu().numpy()
|
46 |
+
feats_list.extend(feats)
|
47 |
+
|
48 |
+
os.makedirs(os.path.join(save_path, 'simclr_files', file_name), exist_ok=True)
|
49 |
+
|
50 |
+
txt_file = open(os.path.join(save_path, 'simclr_files', file_name, 'c_idx.txt'), "w+")
|
51 |
+
save_coords(txt_file, csv_file_path)
|
52 |
+
# save node features
|
53 |
+
output = torch.stack(feats_list, dim=0).to(device)
|
54 |
+
torch.save(output, os.path.join(save_path, 'simclr_files', file_name, 'features.pt'))
|
55 |
+
# save adjacent matrix
|
56 |
+
adj_s = adj_matrix(csv_file_path, output, device=device)
|
57 |
+
torch.save(adj_s, os.path.join(save_path, 'simclr_files', file_name, 'adj_s.pt'))
|
58 |
+
|
59 |
+
print('\r Computed: {}/{}'.format(i+1, num_bags))
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder')
|
64 |
+
parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes')
|
65 |
+
parser.add_argument('--num_feats', default=512, type=int, help='Feature size')
|
66 |
+
parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader')
|
67 |
+
parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for datalodaer')
|
68 |
+
parser.add_argument('--dataset', default=None, type=str, help='path to patches')
|
69 |
+
parser.add_argument('--backbone', default='resnet18', type=str, help='Embedder backbone')
|
70 |
+
parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features')
|
71 |
+
parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights')
|
72 |
+
parser.add_argument('--output', default=None, type=str, help='path to the output graph folder')
|
73 |
+
args = parser.parse_args()
|
74 |
+
|
75 |
+
if args.backbone == 'resnet18':
|
76 |
+
resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
|
77 |
+
num_feats = 512
|
78 |
+
if args.backbone == 'resnet34':
|
79 |
+
resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
|
80 |
+
num_feats = 512
|
81 |
+
if args.backbone == 'resnet50':
|
82 |
+
resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
|
83 |
+
num_feats = 2048
|
84 |
+
if args.backbone == 'resnet101':
|
85 |
+
resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
|
86 |
+
num_feats = 2048
|
87 |
+
for param in resnet.parameters():
|
88 |
+
param.requires_grad = False
|
89 |
+
resnet.fc = nn.Identity()
|
90 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
91 |
+
print("Running on:", device)
|
92 |
+
i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device)
|
93 |
+
|
94 |
+
# load feature extractor
|
95 |
+
if args.weights is None:
|
96 |
+
print('No feature extractor')
|
97 |
+
return
|
98 |
+
state_dict_weights = torch.load(args.weights)
|
99 |
+
state_dict_init = i_classifier.state_dict()
|
100 |
+
new_state_dict = OrderedDict()
|
101 |
+
for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
|
102 |
+
if 'features' not in k:
|
103 |
+
continue
|
104 |
+
name = k_0
|
105 |
+
new_state_dict[name] = v
|
106 |
+
i_classifier.load_state_dict(new_state_dict, strict=False)
|
107 |
+
|
108 |
+
os.makedirs(args.output, exist_ok=True)
|
109 |
+
bags_list = glob.glob(args.dataset)
|
110 |
+
print(bags_list)
|
111 |
+
compute_feats(args, bags_list, i_classifier, device, args.output)
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
main()
|
feature_extractor/cl.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Variable
|
5 |
+
|
6 |
+
class FCLayer(nn.Module):
|
7 |
+
def __init__(self, in_size, out_size=1):
|
8 |
+
super(FCLayer, self).__init__()
|
9 |
+
self.fc = nn.Sequential(nn.Linear(in_size, out_size))
|
10 |
+
def forward(self, feats):
|
11 |
+
x = self.fc(feats)
|
12 |
+
return feats, x
|
13 |
+
|
14 |
+
class IClassifier(nn.Module):
|
15 |
+
def __init__(self, feature_extractor, feature_size, output_class):
|
16 |
+
super(IClassifier, self).__init__()
|
17 |
+
|
18 |
+
self.feature_extractor = feature_extractor
|
19 |
+
self.fc = nn.Linear(feature_size, output_class)
|
20 |
+
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
device = x.device
|
24 |
+
feats = self.feature_extractor(x) # N x K
|
25 |
+
c = self.fc(feats.view(feats.shape[0], -1)) # N x C
|
26 |
+
return feats.view(feats.shape[0], -1), c
|
27 |
+
|
28 |
+
class BClassifier(nn.Module):
|
29 |
+
def __init__(self, input_size, output_class, dropout_v=0.0): # K, L, N
|
30 |
+
super(BClassifier, self).__init__()
|
31 |
+
self.q = nn.Linear(input_size, 128)
|
32 |
+
self.v = nn.Sequential(
|
33 |
+
nn.Dropout(dropout_v),
|
34 |
+
nn.Linear(input_size, input_size)
|
35 |
+
)
|
36 |
+
|
37 |
+
### 1D convolutional layer that can handle multiple class (including binary)
|
38 |
+
self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)
|
39 |
+
|
40 |
+
def forward(self, feats, c): # N x K, N x C
|
41 |
+
device = feats.device
|
42 |
+
V = self.v(feats) # N x V, unsorted
|
43 |
+
Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted
|
44 |
+
|
45 |
+
# handle multiple classes without for loop
|
46 |
+
_, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
|
47 |
+
m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
|
48 |
+
q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
|
49 |
+
A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
|
50 |
+
A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
|
51 |
+
B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V
|
52 |
+
|
53 |
+
|
54 |
+
# for i in range(c.shape[1]):
|
55 |
+
# _, indices = torch.sort(c[:, i], 0, True)
|
56 |
+
# feats = torch.index_select(feats, 0, indices) # N x K, sorted
|
57 |
+
# q_max = self.q(feats[0].view(1, -1)) # 1 x 1 x Q
|
58 |
+
# temp = torch.mm(Q, q_max.view(-1, 1)) / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device))
|
59 |
+
# if i == 0:
|
60 |
+
# A = F.softmax(temp, 0) # N x 1
|
61 |
+
# B = torch.sum(torch.mul(A, V), 0).view(1, -1) # 1 x V
|
62 |
+
# else:
|
63 |
+
# temp = F.softmax(temp, 0) # N x 1
|
64 |
+
# A = torch.cat((A, temp), 1) # N x C
|
65 |
+
# B = torch.cat((B, torch.sum(torch.mul(temp, V), 0).view(1, -1)), 0) # C x V -> 1 x C x V
|
66 |
+
|
67 |
+
B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
|
68 |
+
C = self.fcc(B) # 1 x C x 1
|
69 |
+
C = C.view(1, -1)
|
70 |
+
return C, A, B
|
71 |
+
|
72 |
+
class MILNet(nn.Module):
|
73 |
+
def __init__(self, i_classifier, b_classifier):
|
74 |
+
super(MILNet, self).__init__()
|
75 |
+
self.i_classifier = i_classifier
|
76 |
+
self.b_classifier = b_classifier
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
feats, classes = self.i_classifier(x)
|
80 |
+
prediction_bag, A, B = self.b_classifier(feats, classes)
|
81 |
+
|
82 |
+
return classes, prediction_bag, A, B
|
83 |
+
|
feature_extractor/config.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
batch_size: 256
|
2 |
+
epochs: 20
|
3 |
+
eval_every_n_epochs: 1
|
4 |
+
fine_tune_from: ''
|
5 |
+
log_every_n_steps: 25
|
6 |
+
weight_decay: 10e-6
|
7 |
+
fp16_precision: False
|
8 |
+
n_gpu: 2
|
9 |
+
gpu_ids: (0,1)
|
10 |
+
|
11 |
+
model:
|
12 |
+
out_dim: 512
|
13 |
+
base_model: "resnet18"
|
14 |
+
|
15 |
+
dataset:
|
16 |
+
s: 1
|
17 |
+
input_shape: (224,224,3)
|
18 |
+
num_workers: 10
|
19 |
+
valid_size: 0.1
|
20 |
+
|
21 |
+
loss:
|
22 |
+
temperature: 0.5
|
23 |
+
use_cosine_similarity: True
|
feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc
ADDED
Binary file (3.83 kB). View file
|
|
feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc
ADDED
Binary file (4 kB). View file
|
|
feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc
ADDED
Binary file (896 Bytes). View file
|
|
feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc
ADDED
Binary file (932 Bytes). View file
|
|
feature_extractor/data_aug/dataset_wrapper.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
from data_aug.gaussian_blur import GaussianBlur
|
6 |
+
from torchvision import datasets
|
7 |
+
import pandas as pd
|
8 |
+
from PIL import Image
|
9 |
+
from skimage import io, img_as_ubyte
|
10 |
+
|
11 |
+
np.random.seed(0)
|
12 |
+
|
13 |
+
class Dataset():
|
14 |
+
def __init__(self, csv_file, transform=None):
|
15 |
+
lines = []
|
16 |
+
with open(csv_file) as f:
|
17 |
+
for line in f:
|
18 |
+
line = line.rstrip().strip()
|
19 |
+
lines.append(line)
|
20 |
+
self.files_list = lines#pd.read_csv(csv_file)
|
21 |
+
self.transform = transform
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.files_list)
|
24 |
+
def __getitem__(self, idx):
|
25 |
+
temp_path = self.files_list[idx]# self.files_list.iloc[idx, 0]
|
26 |
+
img = Image.open(temp_path)
|
27 |
+
img = transforms.functional.to_tensor(img)
|
28 |
+
if self.transform:
|
29 |
+
sample = self.transform(img)
|
30 |
+
return sample
|
31 |
+
|
32 |
+
class ToPIL(object):
|
33 |
+
def __call__(self, sample):
|
34 |
+
img = sample
|
35 |
+
img = transforms.functional.to_pil_image(img)
|
36 |
+
return img
|
37 |
+
|
38 |
+
class DataSetWrapper(object):
|
39 |
+
|
40 |
+
def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
|
41 |
+
self.batch_size = batch_size
|
42 |
+
self.num_workers = num_workers
|
43 |
+
self.valid_size = valid_size
|
44 |
+
self.s = s
|
45 |
+
self.input_shape = eval(input_shape)
|
46 |
+
|
47 |
+
def get_data_loaders(self):
|
48 |
+
data_augment = self._get_simclr_pipeline_transform()
|
49 |
+
train_dataset = Dataset(csv_file='all_patches.csv', transform=SimCLRDataTransform(data_augment))
|
50 |
+
train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset)
|
51 |
+
return train_loader, valid_loader
|
52 |
+
|
53 |
+
def _get_simclr_pipeline_transform(self):
|
54 |
+
# get a set of data augmentation transformations as described in the SimCLR paper.
|
55 |
+
color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
|
56 |
+
data_transforms = transforms.Compose([ToPIL(),
|
57 |
+
# transforms.RandomResizedCrop(size=self.input_shape[0]),
|
58 |
+
transforms.Resize((self.input_shape[0],self.input_shape[1])),
|
59 |
+
transforms.RandomHorizontalFlip(),
|
60 |
+
transforms.RandomApply([color_jitter], p=0.8),
|
61 |
+
transforms.RandomGrayscale(p=0.2),
|
62 |
+
GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])),
|
63 |
+
transforms.ToTensor()])
|
64 |
+
return data_transforms
|
65 |
+
|
66 |
+
def get_train_validation_data_loaders(self, train_dataset):
|
67 |
+
# obtain training indices that will be used for validation
|
68 |
+
num_train = len(train_dataset)
|
69 |
+
indices = list(range(num_train))
|
70 |
+
np.random.shuffle(indices)
|
71 |
+
|
72 |
+
split = int(np.floor(self.valid_size * num_train))
|
73 |
+
train_idx, valid_idx = indices[split:], indices[:split]
|
74 |
+
|
75 |
+
# define samplers for obtaining training and validation batches
|
76 |
+
train_sampler = SubsetRandomSampler(train_idx)
|
77 |
+
valid_sampler = SubsetRandomSampler(valid_idx)
|
78 |
+
|
79 |
+
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
|
80 |
+
num_workers=self.num_workers, drop_last=True, shuffle=False)
|
81 |
+
valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
|
82 |
+
num_workers=self.num_workers, drop_last=True)
|
83 |
+
return train_loader, valid_loader
|
84 |
+
|
85 |
+
|
86 |
+
class SimCLRDataTransform(object):
|
87 |
+
def __init__(self, transform):
|
88 |
+
self.transform = transform
|
89 |
+
|
90 |
+
def __call__(self, sample):
|
91 |
+
xi = self.transform(sample)
|
92 |
+
xj = self.transform(sample)
|
93 |
+
return xi, xj
|
feature_extractor/data_aug/gaussian_blur.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
np.random.seed(0)
|
5 |
+
|
6 |
+
|
7 |
+
class GaussianBlur(object):
|
8 |
+
# Implements Gaussian blur as described in the SimCLR paper
|
9 |
+
def __init__(self, kernel_size, min=0.1, max=2.0):
|
10 |
+
self.min = min
|
11 |
+
self.max = max
|
12 |
+
# kernel size is set to be 10% of the image height/width
|
13 |
+
self.kernel_size = kernel_size
|
14 |
+
|
15 |
+
def __call__(self, sample):
|
16 |
+
sample = np.array(sample)
|
17 |
+
|
18 |
+
# blur the image with a 50% chance
|
19 |
+
prob = np.random.random_sample()
|
20 |
+
|
21 |
+
if prob < 0.5:
|
22 |
+
# print(self.kernel_size)
|
23 |
+
sigma = (self.max - self.min) * np.random.random_sample() + self.min
|
24 |
+
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
|
25 |
+
|
26 |
+
return sample
|
feature_extractor/load_patches.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os, glob
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
parser.add_argument('--data_path', type=str)
|
8 |
+
args = parser.parse_args()
|
9 |
+
|
10 |
+
wsi_slides_paths = []
|
11 |
+
|
12 |
+
|
13 |
+
def r(dirpath):
|
14 |
+
for file in os.listdir(dirpath):
|
15 |
+
path = os.path.join(dirpath, file)
|
16 |
+
if os.path.isfile(path) and file.endswith(".svs"):
|
17 |
+
wsi_slides_paths.append(path)
|
18 |
+
elif os.path.isdir(path):
|
19 |
+
r(path)
|
20 |
+
def r(dirpath):
|
21 |
+
for path in glob.glob(os.path.join(dirpath, '*','*.svs') ):#os.listdir(dirpath):
|
22 |
+
if os.path.isfile(path):
|
23 |
+
wsi_slides_paths.append(path)
|
24 |
+
def r(dirpath):
|
25 |
+
for path in glob.glob(os.path.join(dirpath, '*', '*', '*.jpeg') ):#os.listdir(dirpath):
|
26 |
+
if os.path.isfile(path):
|
27 |
+
wsi_slides_paths.append(path)
|
28 |
+
r(args.data_path)
|
29 |
+
with open('all_patches.csv', 'w') as f:
|
30 |
+
for filepath in wsi_slides_paths:
|
31 |
+
f.write(f'{filepath}\n')
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
main()
|
feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc
ADDED
Binary file (2.45 kB). View file
|
|
feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc
ADDED
Binary file (2.49 kB). View file
|
|
feature_extractor/loss/nt_xent.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class NTXentLoss(torch.nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, device, batch_size, temperature, use_cosine_similarity):
|
8 |
+
super(NTXentLoss, self).__init__()
|
9 |
+
self.batch_size = batch_size
|
10 |
+
self.temperature = temperature
|
11 |
+
self.device = device
|
12 |
+
self.softmax = torch.nn.Softmax(dim=-1)
|
13 |
+
self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
|
14 |
+
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
|
15 |
+
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
|
16 |
+
|
17 |
+
def _get_similarity_function(self, use_cosine_similarity):
|
18 |
+
if use_cosine_similarity:
|
19 |
+
self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
|
20 |
+
return self._cosine_simililarity
|
21 |
+
else:
|
22 |
+
return self._dot_simililarity
|
23 |
+
|
24 |
+
def _get_correlated_mask(self):
|
25 |
+
diag = np.eye(2 * self.batch_size)
|
26 |
+
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
|
27 |
+
l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
|
28 |
+
mask = torch.from_numpy((diag + l1 + l2))
|
29 |
+
mask = (1 - mask).type(torch.bool)
|
30 |
+
return mask.to(self.device)
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def _dot_simililarity(x, y):
|
34 |
+
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
|
35 |
+
# x shape: (N, 1, C)
|
36 |
+
# y shape: (1, C, 2N)
|
37 |
+
# v shape: (N, 2N)
|
38 |
+
return v
|
39 |
+
|
40 |
+
def _cosine_simililarity(self, x, y):
|
41 |
+
# x shape: (N, 1, C)
|
42 |
+
# y shape: (1, 2N, C)
|
43 |
+
# v shape: (N, 2N)
|
44 |
+
v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
|
45 |
+
return v
|
46 |
+
|
47 |
+
def forward(self, zis, zjs):
|
48 |
+
representations = torch.cat([zjs, zis], dim=0)
|
49 |
+
|
50 |
+
similarity_matrix = self.similarity_function(representations, representations)
|
51 |
+
|
52 |
+
# filter out the scores from the positive samples
|
53 |
+
l_pos = torch.diag(similarity_matrix, self.batch_size)
|
54 |
+
r_pos = torch.diag(similarity_matrix, -self.batch_size)
|
55 |
+
positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
|
56 |
+
|
57 |
+
negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
|
58 |
+
|
59 |
+
logits = torch.cat((positives, negatives), dim=1)
|
60 |
+
logits /= self.temperature
|
61 |
+
|
62 |
+
labels = torch.zeros(2 * self.batch_size).to(self.device).long()
|
63 |
+
loss = self.criterion(logits, labels)
|
64 |
+
|
65 |
+
return loss / (2 * self.batch_size)
|
feature_extractor/models/__init__.py
ADDED
File without changes
|
feature_extractor/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (186 Bytes). View file
|
|
feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc
ADDED
Binary file (1.51 kB). View file
|
|
feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc
ADDED
Binary file (1.55 kB). View file
|
|
feature_extractor/models/baseline_encoder.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as models
|
5 |
+
|
6 |
+
|
7 |
+
class Encoder(nn.Module):
|
8 |
+
def __init__(self, out_dim=64):
|
9 |
+
super(Encoder, self).__init__()
|
10 |
+
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
11 |
+
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
|
12 |
+
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
13 |
+
self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
14 |
+
self.pool = nn.MaxPool2d(2, 2)
|
15 |
+
|
16 |
+
# projection MLP
|
17 |
+
self.l1 = nn.Linear(64, 64)
|
18 |
+
self.l2 = nn.Linear(64, out_dim)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
x = self.conv1(x)
|
22 |
+
x = F.relu(x)
|
23 |
+
x = self.pool(x)
|
24 |
+
|
25 |
+
x = self.conv2(x)
|
26 |
+
x = F.relu(x)
|
27 |
+
x = self.pool(x)
|
28 |
+
|
29 |
+
x = self.conv3(x)
|
30 |
+
x = F.relu(x)
|
31 |
+
x = self.pool(x)
|
32 |
+
|
33 |
+
x = self.conv4(x)
|
34 |
+
x = F.relu(x)
|
35 |
+
x = self.pool(x)
|
36 |
+
|
37 |
+
h = torch.mean(x, dim=[2, 3])
|
38 |
+
|
39 |
+
x = self.l1(h)
|
40 |
+
x = F.relu(x)
|
41 |
+
x = self.l2(x)
|
42 |
+
|
43 |
+
return h, x
|
feature_extractor/models/resnet_simclr.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torchvision.models as models
|
4 |
+
|
5 |
+
|
6 |
+
class ResNetSimCLR(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, base_model, out_dim):
|
9 |
+
super(ResNetSimCLR, self).__init__()
|
10 |
+
self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d),
|
11 |
+
"resnet50": models.resnet50(pretrained=False)}
|
12 |
+
|
13 |
+
resnet = self._get_basemodel(base_model)
|
14 |
+
num_ftrs = resnet.fc.in_features
|
15 |
+
|
16 |
+
self.features = nn.Sequential(*list(resnet.children())[:-1])
|
17 |
+
|
18 |
+
# projection MLP
|
19 |
+
self.l1 = nn.Linear(num_ftrs, num_ftrs)
|
20 |
+
self.l2 = nn.Linear(num_ftrs, out_dim)
|
21 |
+
|
22 |
+
def _get_basemodel(self, model_name):
|
23 |
+
try:
|
24 |
+
model = self.resnet_dict[model_name]
|
25 |
+
print("Feature extractor:", model_name)
|
26 |
+
return model
|
27 |
+
except:
|
28 |
+
raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
h = self.features(x)
|
32 |
+
h = h.squeeze()
|
33 |
+
|
34 |
+
x = self.l1(h)
|
35 |
+
x = F.relu(x)
|
36 |
+
x = self.l2(x)
|
37 |
+
return h, x
|
feature_extractor/run.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from simclr import SimCLR
|
2 |
+
import yaml
|
3 |
+
from data_aug.dataset_wrapper import DataSetWrapper
|
4 |
+
import os, glob
|
5 |
+
import pandas as pd
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
def main():
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument('--magnification', type=str, default='20x')
|
11 |
+
parser.add_argument('--dest_weights', type=str)
|
12 |
+
args = parser.parse_args()
|
13 |
+
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
|
14 |
+
dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
|
15 |
+
|
16 |
+
simclr = SimCLR(dataset, config, args)
|
17 |
+
simclr.train()
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
main()
|
feature_extractor/simclr.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models.resnet_simclr import ResNetSimCLR
|
3 |
+
from torch.utils.tensorboard import SummaryWriter
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from loss.nt_xent import NTXentLoss
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
import sys
|
9 |
+
|
10 |
+
apex_support = False
|
11 |
+
try:
|
12 |
+
sys.path.append('./apex')
|
13 |
+
from apex import amp
|
14 |
+
|
15 |
+
apex_support = True
|
16 |
+
except:
|
17 |
+
print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
|
18 |
+
apex_support = False
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
torch.manual_seed(0)
|
23 |
+
|
24 |
+
|
25 |
+
def _save_config_file(model_checkpoints_folder):
|
26 |
+
if not os.path.exists(model_checkpoints_folder):
|
27 |
+
os.makedirs(model_checkpoints_folder)
|
28 |
+
shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))
|
29 |
+
|
30 |
+
|
31 |
+
class SimCLR(object):
|
32 |
+
|
33 |
+
def __init__(self, dataset, config, args=None):
|
34 |
+
self.config = config
|
35 |
+
self.device = self._get_device()
|
36 |
+
self.writer = SummaryWriter()
|
37 |
+
self.dataset = dataset
|
38 |
+
self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
|
39 |
+
self.args = args
|
40 |
+
def _get_device(self):
|
41 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
42 |
+
print("Running on:", device)
|
43 |
+
return device
|
44 |
+
|
45 |
+
def _step(self, model, xis, xjs, n_iter):
|
46 |
+
|
47 |
+
# get the representations and the projections
|
48 |
+
ris, zis = model(xis) # [N,C]
|
49 |
+
|
50 |
+
# get the representations and the projections
|
51 |
+
rjs, zjs = model(xjs) # [N,C]
|
52 |
+
|
53 |
+
# normalize projection feature vectors
|
54 |
+
zis = F.normalize(zis, dim=1)
|
55 |
+
zjs = F.normalize(zjs, dim=1)
|
56 |
+
|
57 |
+
loss = self.nt_xent_criterion(zis, zjs)
|
58 |
+
return loss
|
59 |
+
|
60 |
+
def train(self):
|
61 |
+
|
62 |
+
train_loader, valid_loader = self.dataset.get_data_loaders()
|
63 |
+
|
64 |
+
model = ResNetSimCLR(**self.config["model"])# .to(self.device)
|
65 |
+
if self.config['n_gpu'] > 1:
|
66 |
+
model = torch.nn.DataParallel(model, device_ids=eval(self.config['gpu_ids']))
|
67 |
+
model = self._load_pre_trained_weights(model)
|
68 |
+
model = model.to(self.device)
|
69 |
+
|
70 |
+
|
71 |
+
optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=eval(self.config['weight_decay']))
|
72 |
+
|
73 |
+
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
|
74 |
+
# last_epoch=-1)
|
75 |
+
|
76 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config['epochs'], eta_min=0,
|
77 |
+
last_epoch=-1)
|
78 |
+
|
79 |
+
|
80 |
+
if apex_support and self.config['fp16_precision']:
|
81 |
+
model, optimizer = amp.initialize(model, optimizer,
|
82 |
+
opt_level='O2',
|
83 |
+
keep_batchnorm_fp32=True)
|
84 |
+
|
85 |
+
if self.args is None:
|
86 |
+
model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
|
87 |
+
else:
|
88 |
+
model_checkpoints_folder = self.args.dest_weights#os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
|
89 |
+
model_checkpoints_folder = os.path.dirname(model_checkpoints_folder)
|
90 |
+
# save config file
|
91 |
+
_save_config_file(model_checkpoints_folder)
|
92 |
+
|
93 |
+
n_iter = 0
|
94 |
+
valid_n_iter = 0
|
95 |
+
best_valid_loss = np.inf
|
96 |
+
|
97 |
+
for epoch_counter in range(self.config['epochs']):
|
98 |
+
for (xis, xjs) in train_loader:
|
99 |
+
optimizer.zero_grad()
|
100 |
+
xis = xis.to(self.device)
|
101 |
+
xjs = xjs.to(self.device)
|
102 |
+
|
103 |
+
loss = self._step(model, xis, xjs, n_iter)
|
104 |
+
|
105 |
+
if n_iter % self.config['log_every_n_steps'] == 0:
|
106 |
+
self.writer.add_scalar('train_loss', loss, global_step=n_iter)
|
107 |
+
print("[%d/%d] step: %d train_loss: %.3f" % (epoch_counter, self.config['epochs'], n_iter, loss))
|
108 |
+
|
109 |
+
if apex_support and self.config['fp16_precision']:
|
110 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
111 |
+
scaled_loss.backward()
|
112 |
+
else:
|
113 |
+
loss.backward()
|
114 |
+
|
115 |
+
optimizer.step()
|
116 |
+
n_iter += 1
|
117 |
+
|
118 |
+
# validate the model if requested
|
119 |
+
if epoch_counter % self.config['eval_every_n_epochs'] == 0:
|
120 |
+
valid_loss = self._validate(model, valid_loader)
|
121 |
+
print("[%d/%d] val_loss: %.3f" % (epoch_counter, self.config['epochs'], valid_loss))
|
122 |
+
if valid_loss < best_valid_loss:
|
123 |
+
# save the model weights
|
124 |
+
best_valid_loss = valid_loss
|
125 |
+
torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
|
126 |
+
print('saved')
|
127 |
+
|
128 |
+
self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
|
129 |
+
valid_n_iter += 1
|
130 |
+
|
131 |
+
# warmup for the first 10 epochs
|
132 |
+
if epoch_counter >= 10:
|
133 |
+
scheduler.step()
|
134 |
+
self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
|
135 |
+
|
136 |
+
def _load_pre_trained_weights(self, model):
|
137 |
+
try:
|
138 |
+
checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
|
139 |
+
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
|
140 |
+
model.load_state_dict(state_dict)
|
141 |
+
print("Loaded pre-trained model with success.")
|
142 |
+
except FileNotFoundError:
|
143 |
+
print("Pre-trained weights not found. Training from scratch.")
|
144 |
+
|
145 |
+
return model
|
146 |
+
|
147 |
+
def _validate(self, model, valid_loader):
|
148 |
+
|
149 |
+
# validation steps
|
150 |
+
with torch.no_grad():
|
151 |
+
model.eval()
|
152 |
+
|
153 |
+
valid_loss = 0.0
|
154 |
+
counter = 0
|
155 |
+
|
156 |
+
for (xis, xjs) in valid_loader:
|
157 |
+
xis = xis.to(self.device)
|
158 |
+
xjs = xjs.to(self.device)
|
159 |
+
|
160 |
+
loss = self._step(model, xis, xjs, counter)
|
161 |
+
valid_loss += loss.item()
|
162 |
+
counter += 1
|
163 |
+
valid_loss /= counter
|
164 |
+
model.train()
|
165 |
+
return valid_loss
|
feature_extractor/viewer.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
#
|
3 |
+
# deepzoom_server - Example web application for serving whole-slide images
|
4 |
+
#
|
5 |
+
# Copyright (c) 2010-2015 Carnegie Mellon University
|
6 |
+
#
|
7 |
+
# This library is free software; you can redistribute it and/or modify it
|
8 |
+
# under the terms of version 2.1 of the GNU Lesser General Public License
|
9 |
+
# as published by the Free Software Foundation.
|
10 |
+
#
|
11 |
+
# This library is distributed in the hope that it will be useful, but
|
12 |
+
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
|
13 |
+
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
14 |
+
# License for more details.
|
15 |
+
#
|
16 |
+
# You should have received a copy of the GNU Lesser General Public License
|
17 |
+
# along with this library; if not, write to the Free Software Foundation,
|
18 |
+
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
19 |
+
#
|
20 |
+
|
21 |
+
from io import BytesIO
|
22 |
+
from optparse import OptionParser
|
23 |
+
import os
|
24 |
+
import re
|
25 |
+
from unicodedata import normalize
|
26 |
+
|
27 |
+
from flask import Flask, abort, make_response, render_template, url_for
|
28 |
+
|
29 |
+
if os.name == 'nt':
|
30 |
+
_dll_path = os.getenv('OPENSLIDE_PATH')
|
31 |
+
if _dll_path is not None:
|
32 |
+
if hasattr(os, 'add_dll_directory'):
|
33 |
+
# Python >= 3.8
|
34 |
+
with os.add_dll_directory(_dll_path):
|
35 |
+
import openslide
|
36 |
+
else:
|
37 |
+
# Python < 3.8
|
38 |
+
_orig_path = os.environ.get('PATH', '')
|
39 |
+
os.environ['PATH'] = _orig_path + ';' + _dll_path
|
40 |
+
import openslide
|
41 |
+
|
42 |
+
os.environ['PATH'] = _orig_path
|
43 |
+
else:
|
44 |
+
import openslide
|
45 |
+
|
46 |
+
from openslide import ImageSlide, open_slide
|
47 |
+
from openslide.deepzoom import DeepZoomGenerator
|
48 |
+
|
49 |
+
DEEPZOOM_SLIDE = None
|
50 |
+
DEEPZOOM_FORMAT = 'jpeg'
|
51 |
+
DEEPZOOM_TILE_SIZE = 254
|
52 |
+
DEEPZOOM_OVERLAP = 1
|
53 |
+
DEEPZOOM_LIMIT_BOUNDS = True
|
54 |
+
DEEPZOOM_TILE_QUALITY = 75
|
55 |
+
SLIDE_NAME = 'slide'
|
56 |
+
|
57 |
+
app = Flask(__name__)
|
58 |
+
app.config.from_object(__name__)
|
59 |
+
app.config.from_envvar('DEEPZOOM_TILER_SETTINGS', silent=True)
|
60 |
+
|
61 |
+
|
62 |
+
@app.before_first_request
|
63 |
+
def load_slide():
|
64 |
+
slidefile = app.config['DEEPZOOM_SLIDE']
|
65 |
+
if slidefile is None:
|
66 |
+
raise ValueError('No slide file specified')
|
67 |
+
config_map = {
|
68 |
+
'DEEPZOOM_TILE_SIZE': 'tile_size',
|
69 |
+
'DEEPZOOM_OVERLAP': 'overlap',
|
70 |
+
'DEEPZOOM_LIMIT_BOUNDS': 'limit_bounds',
|
71 |
+
}
|
72 |
+
opts = {v: app.config[k] for k, v in config_map.items()}
|
73 |
+
slide = open_slide(slidefile)
|
74 |
+
app.slides = {SLIDE_NAME: DeepZoomGenerator(slide, **opts)}
|
75 |
+
app.associated_images = []
|
76 |
+
app.slide_properties = slide.properties
|
77 |
+
for name, image in slide.associated_images.items():
|
78 |
+
app.associated_images.append(name)
|
79 |
+
slug = slugify(name)
|
80 |
+
app.slides[slug] = DeepZoomGenerator(ImageSlide(image), **opts)
|
81 |
+
try:
|
82 |
+
mpp_x = slide.properties[openslide.PROPERTY_NAME_MPP_X]
|
83 |
+
mpp_y = slide.properties[openslide.PROPERTY_NAME_MPP_Y]
|
84 |
+
app.slide_mpp = (float(mpp_x) + float(mpp_y)) / 2
|
85 |
+
except (KeyError, ValueError):
|
86 |
+
app.slide_mpp = 0
|
87 |
+
|
88 |
+
|
89 |
+
@app.route('/')
|
90 |
+
def index():
|
91 |
+
slide_url = url_for('dzi', slug=SLIDE_NAME)
|
92 |
+
associated_urls = {
|
93 |
+
name: url_for('dzi', slug=slugify(name)) for name in app.associated_images
|
94 |
+
}
|
95 |
+
return render_template(
|
96 |
+
'slide-multipane.html',
|
97 |
+
slide_url=slide_url,
|
98 |
+
associated=associated_urls,
|
99 |
+
properties=app.slide_properties,
|
100 |
+
slide_mpp=app.slide_mpp,
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
@app.route('/<slug>.dzi')
|
105 |
+
def dzi(slug):
|
106 |
+
format = app.config['DEEPZOOM_FORMAT']
|
107 |
+
try:
|
108 |
+
resp = make_response(app.slides[slug].get_dzi(format))
|
109 |
+
resp.mimetype = 'application/xml'
|
110 |
+
return resp
|
111 |
+
except KeyError:
|
112 |
+
# Unknown slug
|
113 |
+
abort(404)
|
114 |
+
|
115 |
+
|
116 |
+
@app.route('/<slug>_files/<int:level>/<int:col>_<int:row>.<format>')
|
117 |
+
def tile(slug, level, col, row, format):
|
118 |
+
format = format.lower()
|
119 |
+
if format != 'jpeg' and format != 'png':
|
120 |
+
# Not supported by Deep Zoom
|
121 |
+
abort(404)
|
122 |
+
try:
|
123 |
+
tile = app.slides[slug].get_tile(level, (col, row))
|
124 |
+
except KeyError:
|
125 |
+
# Unknown slug
|
126 |
+
abort(404)
|
127 |
+
except ValueError:
|
128 |
+
# Invalid level or coordinates
|
129 |
+
abort(404)
|
130 |
+
buf = BytesIO()
|
131 |
+
tile.save(buf, format, quality=app.config['DEEPZOOM_TILE_QUALITY'])
|
132 |
+
resp = make_response(buf.getvalue())
|
133 |
+
resp.mimetype = 'image/%s' % format
|
134 |
+
return resp
|
135 |
+
|
136 |
+
|
137 |
+
def slugify(text):
|
138 |
+
text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode()
|
139 |
+
return re.sub('[^a-z0-9]+', '-', text)
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == '__main__':
|
143 |
+
parser = OptionParser(usage='Usage: %prog [options] [slide]')
|
144 |
+
parser.add_option(
|
145 |
+
'-B',
|
146 |
+
'--ignore-bounds',
|
147 |
+
dest='DEEPZOOM_LIMIT_BOUNDS',
|
148 |
+
default=True,
|
149 |
+
action='store_false',
|
150 |
+
help='display entire scan area',
|
151 |
+
)
|
152 |
+
parser.add_option(
|
153 |
+
'-c', '--config', metavar='FILE', dest='config', help='config file'
|
154 |
+
)
|
155 |
+
parser.add_option(
|
156 |
+
'-d',
|
157 |
+
'--debug',
|
158 |
+
dest='DEBUG',
|
159 |
+
action='store_true',
|
160 |
+
help='run in debugging mode (insecure)',
|
161 |
+
)
|
162 |
+
parser.add_option(
|
163 |
+
'-e',
|
164 |
+
'--overlap',
|
165 |
+
metavar='PIXELS',
|
166 |
+
dest='DEEPZOOM_OVERLAP',
|
167 |
+
type='int',
|
168 |
+
help='overlap of adjacent tiles [1]',
|
169 |
+
)
|
170 |
+
parser.add_option(
|
171 |
+
'-f',
|
172 |
+
'--format',
|
173 |
+
metavar='{jpeg|png}',
|
174 |
+
dest='DEEPZOOM_FORMAT',
|
175 |
+
help='image format for tiles [jpeg]',
|
176 |
+
)
|
177 |
+
parser.add_option(
|
178 |
+
'-l',
|
179 |
+
'--listen',
|
180 |
+
metavar='ADDRESS',
|
181 |
+
dest='host',
|
182 |
+
default='127.0.0.1',
|
183 |
+
help='address to listen on [127.0.0.1]',
|
184 |
+
)
|
185 |
+
parser.add_option(
|
186 |
+
'-p',
|
187 |
+
'--port',
|
188 |
+
metavar='PORT',
|
189 |
+
dest='port',
|
190 |
+
type='int',
|
191 |
+
default=5000,
|
192 |
+
help='port to listen on [5000]',
|
193 |
+
)
|
194 |
+
parser.add_option(
|
195 |
+
'-Q',
|
196 |
+
'--quality',
|
197 |
+
metavar='QUALITY',
|
198 |
+
dest='DEEPZOOM_TILE_QUALITY',
|
199 |
+
type='int',
|
200 |
+
help='JPEG compression quality [75]',
|
201 |
+
)
|
202 |
+
parser.add_option(
|
203 |
+
'-s',
|
204 |
+
'--size',
|
205 |
+
metavar='PIXELS',
|
206 |
+
dest='DEEPZOOM_TILE_SIZE',
|
207 |
+
type='int',
|
208 |
+
help='tile size [254]',
|
209 |
+
)
|
210 |
+
|
211 |
+
(opts, args) = parser.parse_args()
|
212 |
+
# Load config file if specified
|
213 |
+
if opts.config is not None:
|
214 |
+
app.config.from_pyfile(opts.config)
|
215 |
+
# Overwrite only those settings specified on the command line
|
216 |
+
for k in dir(opts):
|
217 |
+
if not k.startswith('_') and getattr(opts, k) is None:
|
218 |
+
delattr(opts, k)
|
219 |
+
app.config.from_object(opts)
|
220 |
+
# Set slide file
|
221 |
+
try:
|
222 |
+
app.config['DEEPZOOM_SLIDE'] = args[0]
|
223 |
+
except IndexError:
|
224 |
+
if app.config['DEEPZOOM_SLIDE'] is None:
|
225 |
+
parser.error('No slide file specified')
|
226 |
+
|
227 |
+
app.run(host=opts.host, port=opts.port, threaded=True)
|
helper.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
from __future__ import absolute_import, division, print_function
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.autograd import Variable
|
12 |
+
from torchvision import transforms
|
13 |
+
from utils.metrics import ConfusionMatrix
|
14 |
+
from PIL import Image
|
15 |
+
import os
|
16 |
+
|
17 |
+
# torch.cuda.synchronize()
|
18 |
+
# torch.backends.cudnn.benchmark = True
|
19 |
+
torch.backends.cudnn.deterministic = True
|
20 |
+
|
21 |
+
def collate(batch):
|
22 |
+
image = [ b['image'] for b in batch ] # w, h
|
23 |
+
label = [ b['label'] for b in batch ]
|
24 |
+
id = [ b['id'] for b in batch ]
|
25 |
+
adj_s = [ b['adj_s'] for b in batch ]
|
26 |
+
return {'image': image, 'label': label, 'id': id, 'adj_s': adj_s}
|
27 |
+
|
28 |
+
def preparefeatureLabel(batch_graph, batch_label, batch_adjs, device='cpu'):
|
29 |
+
batch_size = len(batch_graph)
|
30 |
+
labels = torch.LongTensor(batch_size)
|
31 |
+
max_node_num = 0
|
32 |
+
|
33 |
+
for i in range(batch_size):
|
34 |
+
labels[i] = batch_label[i]
|
35 |
+
max_node_num = max(max_node_num, batch_graph[i].shape[0])
|
36 |
+
|
37 |
+
masks = torch.zeros(batch_size, max_node_num)
|
38 |
+
adjs = torch.zeros(batch_size, max_node_num, max_node_num)
|
39 |
+
batch_node_feat = torch.zeros(batch_size, max_node_num, 512)
|
40 |
+
|
41 |
+
for i in range(batch_size):
|
42 |
+
cur_node_num = batch_graph[i].shape[0]
|
43 |
+
#node attribute feature
|
44 |
+
tmp_node_fea = batch_graph[i]
|
45 |
+
batch_node_feat[i, 0:cur_node_num] = tmp_node_fea
|
46 |
+
|
47 |
+
#adjs
|
48 |
+
adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
|
49 |
+
|
50 |
+
#masks
|
51 |
+
masks[i,0:cur_node_num] = 1
|
52 |
+
|
53 |
+
node_feat = batch_node_feat.to(device)
|
54 |
+
labels = labels.to(device)
|
55 |
+
adjs = adjs.to(device)
|
56 |
+
masks = masks.to(device)
|
57 |
+
|
58 |
+
return node_feat, labels, adjs, masks
|
59 |
+
|
60 |
+
class Trainer(object):
|
61 |
+
def __init__(self, n_class):
|
62 |
+
self.metrics = ConfusionMatrix(n_class)
|
63 |
+
|
64 |
+
def get_scores(self):
|
65 |
+
acc = self.metrics.get_scores()
|
66 |
+
|
67 |
+
return acc
|
68 |
+
|
69 |
+
def reset_metrics(self):
|
70 |
+
self.metrics.reset()
|
71 |
+
|
72 |
+
def plot_cm(self):
|
73 |
+
self.metrics.plotcm()
|
74 |
+
|
75 |
+
def train(self, sample, model):
|
76 |
+
node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
|
77 |
+
pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
|
78 |
+
|
79 |
+
return pred,labels,loss
|
80 |
+
|
81 |
+
class Evaluator(object):
|
82 |
+
def __init__(self, n_class):
|
83 |
+
self.metrics = ConfusionMatrix(n_class)
|
84 |
+
|
85 |
+
def get_scores(self):
|
86 |
+
acc = self.metrics.get_scores()
|
87 |
+
|
88 |
+
return acc
|
89 |
+
|
90 |
+
def reset_metrics(self):
|
91 |
+
self.metrics.reset()
|
92 |
+
|
93 |
+
def plot_cm(self):
|
94 |
+
self.metrics.plotcm()
|
95 |
+
|
96 |
+
def eval_test(self, sample, model, graphcam_flag=False):
|
97 |
+
node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s'])
|
98 |
+
if not graphcam_flag:
|
99 |
+
with torch.no_grad():
|
100 |
+
pred,labels,loss = model.forward(node_feat, labels, adjs, masks)
|
101 |
+
else:
|
102 |
+
torch.set_grad_enabled(True)
|
103 |
+
pred,labels,loss= model.forward(node_feat, labels, adjs, masks, graphcam_flag=graphcam_flag)
|
104 |
+
return pred,labels,loss
|
main.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
from __future__ import absolute_import, division, print_function
|
5 |
+
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
from utils.dataset import GraphDataset
|
13 |
+
from utils.lr_scheduler import LR_Scheduler
|
14 |
+
from tensorboardX import SummaryWriter
|
15 |
+
from helper import Trainer, Evaluator, collate
|
16 |
+
from option import Options
|
17 |
+
|
18 |
+
from models.GraphTransformer import Classifier
|
19 |
+
from models.weight_init import weight_init
|
20 |
+
import pickle
|
21 |
+
args = Options().parse()
|
22 |
+
|
23 |
+
label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb'))
|
24 |
+
|
25 |
+
n_class = len(label_map)
|
26 |
+
|
27 |
+
torch.cuda.synchronize()
|
28 |
+
torch.backends.cudnn.deterministic = True
|
29 |
+
|
30 |
+
data_path = args.data_path
|
31 |
+
model_path = args.model_path
|
32 |
+
if not os.path.isdir(model_path): os.mkdir(model_path)
|
33 |
+
log_path = args.log_path
|
34 |
+
if not os.path.isdir(log_path): os.mkdir(log_path)
|
35 |
+
task_name = args.task_name
|
36 |
+
|
37 |
+
print(task_name)
|
38 |
+
###################################
|
39 |
+
train = args.train
|
40 |
+
test = args.test
|
41 |
+
graphcam = args.graphcam
|
42 |
+
print("train:", train, "test:", test, "graphcam:", graphcam)
|
43 |
+
|
44 |
+
##### Load datasets
|
45 |
+
print("preparing datasets and dataloaders......")
|
46 |
+
batch_size = args.batch_size
|
47 |
+
|
48 |
+
if train:
|
49 |
+
ids_train = open(args.train_set).readlines()
|
50 |
+
dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path)
|
51 |
+
dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True)
|
52 |
+
total_train_num = len(dataloader_train) * batch_size
|
53 |
+
|
54 |
+
ids_val = open(args.val_set).readlines()
|
55 |
+
dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path)
|
56 |
+
dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True)
|
57 |
+
total_val_num = len(dataloader_val) * batch_size
|
58 |
+
|
59 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
60 |
+
##### creating models #############
|
61 |
+
print("creating models......")
|
62 |
+
|
63 |
+
num_epochs = args.num_epochs
|
64 |
+
learning_rate = args.lr
|
65 |
+
|
66 |
+
model = Classifier(n_class)
|
67 |
+
model = nn.DataParallel(model)
|
68 |
+
if args.resume:
|
69 |
+
print('load model{}'.format(args.resume))
|
70 |
+
model.load_state_dict(torch.load(args.resume))
|
71 |
+
|
72 |
+
if torch.cuda.is_available():
|
73 |
+
model = model.cuda()
|
74 |
+
#model.apply(weight_init)
|
75 |
+
|
76 |
+
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 5e-4) # best:5e-4, 4e-3
|
77 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,100], gamma=0.1) # gamma=0.3 # 30,90,130 # 20,90,130 -> 150
|
78 |
+
|
79 |
+
##################################
|
80 |
+
|
81 |
+
criterion = nn.CrossEntropyLoss()
|
82 |
+
|
83 |
+
if not test:
|
84 |
+
writer = SummaryWriter(log_dir=log_path + task_name)
|
85 |
+
f_log = open(log_path + task_name + ".log", 'w')
|
86 |
+
|
87 |
+
trainer = Trainer(n_class)
|
88 |
+
evaluator = Evaluator(n_class)
|
89 |
+
|
90 |
+
best_pred = 0.0
|
91 |
+
for epoch in range(num_epochs):
|
92 |
+
# optimizer.zero_grad()
|
93 |
+
model.train()
|
94 |
+
train_loss = 0.
|
95 |
+
total = 0.
|
96 |
+
|
97 |
+
current_lr = optimizer.param_groups[0]['lr']
|
98 |
+
print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' % (epoch+1, current_lr, best_pred))
|
99 |
+
|
100 |
+
if train:
|
101 |
+
for i_batch, sample_batched in enumerate(dataloader_train):
|
102 |
+
scheduler.step(epoch)
|
103 |
+
|
104 |
+
preds,labels,loss = trainer.train(sample_batched, model)
|
105 |
+
|
106 |
+
optimizer.zero_grad()
|
107 |
+
loss.backward()
|
108 |
+
optimizer.step()
|
109 |
+
|
110 |
+
train_loss += loss
|
111 |
+
total += len(labels)
|
112 |
+
|
113 |
+
trainer.metrics.update(labels, preds)
|
114 |
+
if (i_batch + 1) % args.log_interval_local == 0:
|
115 |
+
print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores()))
|
116 |
+
trainer.plot_cm()
|
117 |
+
|
118 |
+
if not test:
|
119 |
+
print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores()))
|
120 |
+
trainer.plot_cm()
|
121 |
+
|
122 |
+
|
123 |
+
if epoch % 1 == 0:
|
124 |
+
with torch.no_grad():
|
125 |
+
model.eval()
|
126 |
+
print("evaluating...")
|
127 |
+
|
128 |
+
total = 0.
|
129 |
+
batch_idx = 0
|
130 |
+
|
131 |
+
for i_batch, sample_batched in enumerate(dataloader_val):
|
132 |
+
preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam)
|
133 |
+
|
134 |
+
total += len(labels)
|
135 |
+
|
136 |
+
evaluator.metrics.update(labels, preds)
|
137 |
+
|
138 |
+
if (i_batch + 1) % args.log_interval_local == 0:
|
139 |
+
print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores()))
|
140 |
+
evaluator.plot_cm()
|
141 |
+
|
142 |
+
print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores()))
|
143 |
+
evaluator.plot_cm()
|
144 |
+
|
145 |
+
# torch.cuda.empty_cache()
|
146 |
+
|
147 |
+
val_acc = evaluator.get_scores()
|
148 |
+
if val_acc > best_pred:
|
149 |
+
best_pred = val_acc
|
150 |
+
if not test:
|
151 |
+
print("saving model...")
|
152 |
+
torch.save(model.state_dict(), model_path + task_name + ".pth")
|
153 |
+
|
154 |
+
log = ""
|
155 |
+
log = log + 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n"
|
156 |
+
|
157 |
+
log += "================================\n"
|
158 |
+
print(log)
|
159 |
+
if test: break
|
160 |
+
|
161 |
+
f_log.write(log)
|
162 |
+
f_log.flush()
|
163 |
+
|
164 |
+
writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch+1)
|
165 |
+
|
166 |
+
trainer.reset_metrics()
|
167 |
+
evaluator.reset_metrics()
|
168 |
+
|
169 |
+
if not test: f_log.close()
|
metadata/label_map.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ce5be416a8667c9379502eaf8407e6d07bbae03749085190be630bd3b026eb52
|
3 |
+
size 34
|
models/.gitkeep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
models/GraphTransformer.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from torch.autograd import Variable
|
8 |
+
from torch.nn.parameter import Parameter
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.optim as optim
|
12 |
+
|
13 |
+
from .ViT import *
|
14 |
+
from .gcn import GCNBlock
|
15 |
+
|
16 |
+
from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool
|
17 |
+
from torch.nn import Linear
|
18 |
+
class Classifier(nn.Module):
|
19 |
+
def __init__(self, n_class):
|
20 |
+
super(Classifier, self).__init__()
|
21 |
+
|
22 |
+
self.n_class = n_class
|
23 |
+
self.embed_dim = 64
|
24 |
+
self.num_layers = 3
|
25 |
+
self.node_cluster_num = 100
|
26 |
+
|
27 |
+
self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim)
|
28 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
29 |
+
self.criterion = nn.CrossEntropyLoss()
|
30 |
+
|
31 |
+
self.bn = 1
|
32 |
+
self.add_self = 1
|
33 |
+
self.normalize_embedding = 1
|
34 |
+
self.conv1 = GCNBlock(512,self.embed_dim,self.bn,self.add_self,self.normalize_embedding,0.,0) # 64->128
|
35 |
+
self.pool1 = Linear(self.embed_dim, self.node_cluster_num) # 100-> 20
|
36 |
+
|
37 |
+
|
38 |
+
def forward(self,node_feat,labels,adj,mask,is_print=False, graphcam_flag=False, to_file=True):
|
39 |
+
# node_feat, labels = self.PrepareFeatureLabel(batch_graph)
|
40 |
+
cls_loss=node_feat.new_zeros(self.num_layers)
|
41 |
+
rank_loss=node_feat.new_zeros(self.num_layers-1)
|
42 |
+
X=node_feat
|
43 |
+
p_t=[]
|
44 |
+
pred_logits=0
|
45 |
+
visualize_tools=[]
|
46 |
+
if labels is not None:
|
47 |
+
visualize_tools1=[labels.cpu()]
|
48 |
+
embeds=0
|
49 |
+
concats=[]
|
50 |
+
|
51 |
+
layer_acc=[]
|
52 |
+
|
53 |
+
X=mask.unsqueeze(2)*X
|
54 |
+
X = self.conv1(X, adj, mask)
|
55 |
+
s = self.pool1(X)
|
56 |
+
|
57 |
+
|
58 |
+
graphcam_tensors = {}
|
59 |
+
|
60 |
+
if graphcam_flag:
|
61 |
+
s_matrix = torch.argmax(s[0], dim=1)
|
62 |
+
if to_file:
|
63 |
+
from os import path
|
64 |
+
os.makedirs('graphcam', exist_ok=True)
|
65 |
+
torch.save(s_matrix, 'graphcam/s_matrix.pt')
|
66 |
+
torch.save(s[0], 'graphcam/s_matrix_ori.pt')
|
67 |
+
|
68 |
+
if path.exists('graphcam/att_1.pt'):
|
69 |
+
os.remove('graphcam/att_1.pt')
|
70 |
+
os.remove('graphcam/att_2.pt')
|
71 |
+
os.remove('graphcam/att_3.pt')
|
72 |
+
|
73 |
+
if not to_file:
|
74 |
+
graphcam_tensors['s_matrix'] = s_matrix
|
75 |
+
graphcam_tensors['s_matrix_ori'] = s[0]
|
76 |
+
|
77 |
+
|
78 |
+
X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask)
|
79 |
+
b, _, _ = X.shape
|
80 |
+
cls_token = self.cls_token.repeat(b, 1, 1)
|
81 |
+
X = torch.cat([cls_token, X], dim=1)
|
82 |
+
|
83 |
+
out = self.transformer(X)
|
84 |
+
|
85 |
+
loss = None
|
86 |
+
if labels is not None:
|
87 |
+
# loss
|
88 |
+
loss = self.criterion(out, labels)
|
89 |
+
loss = loss + mc1 + o1
|
90 |
+
# pred
|
91 |
+
pred = out.data.max(1)[1]
|
92 |
+
|
93 |
+
if graphcam_flag:
|
94 |
+
#print('GraphCAM enabled')
|
95 |
+
#print(out.shape)
|
96 |
+
p = F.softmax(out)
|
97 |
+
#print(p.shape)
|
98 |
+
if to_file:
|
99 |
+
torch.save(p, 'graphcam/prob.pt')
|
100 |
+
if not to_file:
|
101 |
+
graphcam_tensors['prob'] = p
|
102 |
+
index = np.argmax(out.cpu().data.numpy(), axis=-1)
|
103 |
+
|
104 |
+
for index_ in range(self.n_class):
|
105 |
+
one_hot = np.zeros((1, out.size()[-1]), dtype=np.float32)
|
106 |
+
one_hot[0, index_] = out[0][index_]
|
107 |
+
one_hot_vector = one_hot
|
108 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
109 |
+
one_hot = torch.sum(one_hot.to( 'cuda' if torch.cuda.is_available() else 'cpu') * out) #!!!!!!!!!!!!!!!!!!!!out-->p
|
110 |
+
self.transformer.zero_grad()
|
111 |
+
one_hot.backward(retain_graph=True)
|
112 |
+
|
113 |
+
kwargs = {"alpha": 1}
|
114 |
+
cam = self.transformer.relprop(torch.tensor(one_hot_vector).to(X.device), method="transformer_attribution", is_ablation=False,
|
115 |
+
start_layer=0, **kwargs)
|
116 |
+
if to_file:
|
117 |
+
torch.save(cam, 'graphcam/cam_{}.pt'.format(index_))
|
118 |
+
if not to_file:
|
119 |
+
graphcam_tensors[f'cam_{index_}'] = cam
|
120 |
+
|
121 |
+
if not to_file:
|
122 |
+
return pred,labels,loss, graphcam_tensors
|
123 |
+
return pred,labels,loss
|
models/ViT.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Vision Transformer (ViT) in PyTorch
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from .layers import *
|
7 |
+
import math
|
8 |
+
|
9 |
+
|
10 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
11 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
12 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
13 |
+
def norm_cdf(x):
|
14 |
+
# Computes standard normal cumulative distribution function
|
15 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
16 |
+
|
17 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
18 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
19 |
+
"The distribution of values may be incorrect.",
|
20 |
+
stacklevel=2)
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
# Values are generated by using a truncated uniform distribution and
|
24 |
+
# then using the inverse CDF for the normal distribution.
|
25 |
+
# Get upper and lower cdf values
|
26 |
+
l = norm_cdf((a - mean) / std)
|
27 |
+
u = norm_cdf((b - mean) / std)
|
28 |
+
|
29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
30 |
+
# [2l-1, 2u-1].
|
31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
32 |
+
|
33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
34 |
+
# standard normal
|
35 |
+
tensor.erfinv_()
|
36 |
+
|
37 |
+
# Transform to proper mean, std
|
38 |
+
tensor.mul_(std * math.sqrt(2.))
|
39 |
+
tensor.add_(mean)
|
40 |
+
|
41 |
+
# Clamp to ensure it's in the proper range
|
42 |
+
tensor.clamp_(min=a, max=b)
|
43 |
+
return tensor
|
44 |
+
|
45 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
46 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
47 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
48 |
+
normal distribution. The values are effectively drawn from the
|
49 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
50 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
51 |
+
the bounds. The method used for generating the random values works
|
52 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
53 |
+
Args:
|
54 |
+
tensor: an n-dimensional `torch.Tensor`
|
55 |
+
mean: the mean of the normal distribution
|
56 |
+
std: the standard deviation of the normal distribution
|
57 |
+
a: the minimum cutoff value
|
58 |
+
b: the maximum cutoff value
|
59 |
+
Examples:
|
60 |
+
>>> w = torch.empty(3, 5)
|
61 |
+
>>> nn.init.trunc_normal_(w)
|
62 |
+
"""
|
63 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
64 |
+
|
65 |
+
def _cfg(url='', **kwargs):
|
66 |
+
return {
|
67 |
+
'url': url,
|
68 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
69 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
70 |
+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
71 |
+
**kwargs
|
72 |
+
}
|
73 |
+
|
74 |
+
|
75 |
+
default_cfgs = {
|
76 |
+
# patch models
|
77 |
+
'vit_small_patch16_224': _cfg(
|
78 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
79 |
+
),
|
80 |
+
'vit_base_patch16_224': _cfg(
|
81 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
82 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
83 |
+
),
|
84 |
+
'vit_large_patch16_224': _cfg(
|
85 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
86 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
87 |
+
}
|
88 |
+
|
89 |
+
def compute_rollout_attention(all_layer_matrices, start_layer=0):
|
90 |
+
# adding residual consideration
|
91 |
+
num_tokens = all_layer_matrices[0].shape[1]
|
92 |
+
batch_size = all_layer_matrices[0].shape[0]
|
93 |
+
eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
|
94 |
+
all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
|
95 |
+
# all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
|
96 |
+
# for i in range(len(all_layer_matrices))]
|
97 |
+
joint_attention = all_layer_matrices[start_layer]
|
98 |
+
for i in range(start_layer+1, len(all_layer_matrices)):
|
99 |
+
joint_attention = all_layer_matrices[i].bmm(joint_attention)
|
100 |
+
return joint_attention
|
101 |
+
|
102 |
+
class Mlp(nn.Module):
|
103 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
|
104 |
+
super().__init__()
|
105 |
+
out_features = out_features or in_features
|
106 |
+
hidden_features = hidden_features or in_features
|
107 |
+
self.fc1 = Linear(in_features, hidden_features)
|
108 |
+
self.act = GELU()
|
109 |
+
self.fc2 = Linear(hidden_features, out_features)
|
110 |
+
self.drop = Dropout(drop)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
x = self.fc1(x)
|
114 |
+
x = self.act(x)
|
115 |
+
x = self.drop(x)
|
116 |
+
x = self.fc2(x)
|
117 |
+
x = self.drop(x)
|
118 |
+
return x
|
119 |
+
|
120 |
+
def relprop(self, cam, **kwargs):
|
121 |
+
cam = self.drop.relprop(cam, **kwargs)
|
122 |
+
cam = self.fc2.relprop(cam, **kwargs)
|
123 |
+
cam = self.act.relprop(cam, **kwargs)
|
124 |
+
cam = self.fc1.relprop(cam, **kwargs)
|
125 |
+
return cam
|
126 |
+
|
127 |
+
|
128 |
+
class Attention(nn.Module):
|
129 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
|
130 |
+
super().__init__()
|
131 |
+
self.num_heads = num_heads
|
132 |
+
head_dim = dim // num_heads
|
133 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
134 |
+
self.scale = head_dim ** -0.5
|
135 |
+
|
136 |
+
# A = Q*K^T
|
137 |
+
self.matmul1 = einsum('bhid,bhjd->bhij')
|
138 |
+
# attn = A*V
|
139 |
+
self.matmul2 = einsum('bhij,bhjd->bhid')
|
140 |
+
|
141 |
+
self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
|
142 |
+
self.attn_drop = Dropout(attn_drop)
|
143 |
+
self.proj = Linear(dim, dim)
|
144 |
+
self.proj_drop = Dropout(proj_drop)
|
145 |
+
self.softmax = Softmax(dim=-1)
|
146 |
+
|
147 |
+
self.attn_cam = None
|
148 |
+
self.attn = None
|
149 |
+
self.v = None
|
150 |
+
self.v_cam = None
|
151 |
+
self.attn_gradients = None
|
152 |
+
|
153 |
+
def get_attn(self):
|
154 |
+
return self.attn
|
155 |
+
|
156 |
+
def save_attn(self, attn):
|
157 |
+
self.attn = attn
|
158 |
+
|
159 |
+
def save_attn_cam(self, cam):
|
160 |
+
self.attn_cam = cam
|
161 |
+
|
162 |
+
def get_attn_cam(self):
|
163 |
+
return self.attn_cam
|
164 |
+
|
165 |
+
def get_v(self):
|
166 |
+
return self.v
|
167 |
+
|
168 |
+
def save_v(self, v):
|
169 |
+
self.v = v
|
170 |
+
|
171 |
+
def save_v_cam(self, cam):
|
172 |
+
self.v_cam = cam
|
173 |
+
|
174 |
+
def get_v_cam(self):
|
175 |
+
return self.v_cam
|
176 |
+
|
177 |
+
def save_attn_gradients(self, attn_gradients):
|
178 |
+
self.attn_gradients = attn_gradients
|
179 |
+
|
180 |
+
def get_attn_gradients(self):
|
181 |
+
return self.attn_gradients
|
182 |
+
|
183 |
+
def forward(self, x):
|
184 |
+
b, n, _, h = *x.shape, self.num_heads
|
185 |
+
qkv = self.qkv(x)
|
186 |
+
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
|
187 |
+
|
188 |
+
self.save_v(v)
|
189 |
+
|
190 |
+
dots = self.matmul1([q, k]) * self.scale
|
191 |
+
|
192 |
+
attn = self.softmax(dots)
|
193 |
+
attn = self.attn_drop(attn)
|
194 |
+
|
195 |
+
# Get attention
|
196 |
+
if False:
|
197 |
+
from os import path
|
198 |
+
if not path.exists('att_1.pt'):
|
199 |
+
torch.save(attn, 'att_1.pt')
|
200 |
+
elif not path.exists('att_2.pt'):
|
201 |
+
torch.save(attn, 'att_2.pt')
|
202 |
+
else:
|
203 |
+
torch.save(attn, 'att_3.pt')
|
204 |
+
|
205 |
+
#comment in training
|
206 |
+
if x.requires_grad:
|
207 |
+
self.save_attn(attn)
|
208 |
+
attn.register_hook(self.save_attn_gradients)
|
209 |
+
|
210 |
+
out = self.matmul2([attn, v])
|
211 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
212 |
+
|
213 |
+
out = self.proj(out)
|
214 |
+
out = self.proj_drop(out)
|
215 |
+
return out
|
216 |
+
|
217 |
+
def relprop(self, cam, **kwargs):
|
218 |
+
cam = self.proj_drop.relprop(cam, **kwargs)
|
219 |
+
cam = self.proj.relprop(cam, **kwargs)
|
220 |
+
cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
|
221 |
+
|
222 |
+
# attn = A*V
|
223 |
+
(cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
|
224 |
+
cam1 /= 2
|
225 |
+
cam_v /= 2
|
226 |
+
|
227 |
+
self.save_v_cam(cam_v)
|
228 |
+
self.save_attn_cam(cam1)
|
229 |
+
|
230 |
+
cam1 = self.attn_drop.relprop(cam1, **kwargs)
|
231 |
+
cam1 = self.softmax.relprop(cam1, **kwargs)
|
232 |
+
|
233 |
+
# A = Q*K^T
|
234 |
+
(cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
|
235 |
+
cam_q /= 2
|
236 |
+
cam_k /= 2
|
237 |
+
|
238 |
+
cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
|
239 |
+
|
240 |
+
return self.qkv.relprop(cam_qkv, **kwargs)
|
241 |
+
|
242 |
+
|
243 |
+
class Block(nn.Module):
|
244 |
+
|
245 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
|
246 |
+
super().__init__()
|
247 |
+
self.norm1 = LayerNorm(dim, eps=1e-6)
|
248 |
+
self.attn = Attention(
|
249 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
250 |
+
self.norm2 = LayerNorm(dim, eps=1e-6)
|
251 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
252 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
|
253 |
+
|
254 |
+
self.add1 = Add()
|
255 |
+
self.add2 = Add()
|
256 |
+
self.clone1 = Clone()
|
257 |
+
self.clone2 = Clone()
|
258 |
+
|
259 |
+
def forward(self, x):
|
260 |
+
x1, x2 = self.clone1(x, 2)
|
261 |
+
x = self.add1([x1, self.attn(self.norm1(x2))])
|
262 |
+
x1, x2 = self.clone2(x, 2)
|
263 |
+
x = self.add2([x1, self.mlp(self.norm2(x2))])
|
264 |
+
return x
|
265 |
+
|
266 |
+
def relprop(self, cam, **kwargs):
|
267 |
+
(cam1, cam2) = self.add2.relprop(cam, **kwargs)
|
268 |
+
cam2 = self.mlp.relprop(cam2, **kwargs)
|
269 |
+
cam2 = self.norm2.relprop(cam2, **kwargs)
|
270 |
+
cam = self.clone2.relprop((cam1, cam2), **kwargs)
|
271 |
+
|
272 |
+
(cam1, cam2) = self.add1.relprop(cam, **kwargs)
|
273 |
+
cam2 = self.attn.relprop(cam2, **kwargs)
|
274 |
+
cam2 = self.norm1.relprop(cam2, **kwargs)
|
275 |
+
cam = self.clone1.relprop((cam1, cam2), **kwargs)
|
276 |
+
return cam
|
277 |
+
|
278 |
+
class VisionTransformer(nn.Module):
|
279 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
280 |
+
"""
|
281 |
+
def __init__(self, num_classes=2, embed_dim=64, depth=3,
|
282 |
+
num_heads=8, mlp_ratio=2., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
|
283 |
+
super().__init__()
|
284 |
+
self.num_classes = num_classes
|
285 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
286 |
+
|
287 |
+
self.blocks = nn.ModuleList([
|
288 |
+
Block(
|
289 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
290 |
+
drop=drop_rate, attn_drop=attn_drop_rate)
|
291 |
+
for i in range(depth)])
|
292 |
+
|
293 |
+
self.norm = LayerNorm(embed_dim)
|
294 |
+
if mlp_head:
|
295 |
+
# paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
|
296 |
+
self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
|
297 |
+
else:
|
298 |
+
# with a single Linear layer as head, the param count within rounding of paper
|
299 |
+
self.head = Linear(embed_dim, num_classes)
|
300 |
+
|
301 |
+
#self.apply(self._init_weights)
|
302 |
+
|
303 |
+
self.pool = IndexSelect()
|
304 |
+
self.add = Add()
|
305 |
+
|
306 |
+
self.inp_grad = None
|
307 |
+
|
308 |
+
def save_inp_grad(self,grad):
|
309 |
+
self.inp_grad = grad
|
310 |
+
|
311 |
+
def get_inp_grad(self):
|
312 |
+
return self.inp_grad
|
313 |
+
|
314 |
+
|
315 |
+
def _init_weights(self, m):
|
316 |
+
if isinstance(m, nn.Linear):
|
317 |
+
trunc_normal_(m.weight, std=.02)
|
318 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
319 |
+
nn.init.constant_(m.bias, 0)
|
320 |
+
elif isinstance(m, nn.LayerNorm):
|
321 |
+
nn.init.constant_(m.bias, 0)
|
322 |
+
nn.init.constant_(m.weight, 1.0)
|
323 |
+
|
324 |
+
@property
|
325 |
+
def no_weight_decay(self):
|
326 |
+
return {'pos_embed', 'cls_token'}
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
if x.requires_grad:
|
330 |
+
x.register_hook(self.save_inp_grad) #comment it in train
|
331 |
+
|
332 |
+
for blk in self.blocks:
|
333 |
+
x = blk(x)
|
334 |
+
|
335 |
+
x = self.norm(x)
|
336 |
+
x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
|
337 |
+
x = x.squeeze(1)
|
338 |
+
x = self.head(x)
|
339 |
+
return x
|
340 |
+
|
341 |
+
def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
|
342 |
+
# print(kwargs)
|
343 |
+
# print("conservation 1", cam.sum())
|
344 |
+
cam = self.head.relprop(cam, **kwargs)
|
345 |
+
cam = cam.unsqueeze(1)
|
346 |
+
cam = self.pool.relprop(cam, **kwargs)
|
347 |
+
cam = self.norm.relprop(cam, **kwargs)
|
348 |
+
for blk in reversed(self.blocks):
|
349 |
+
cam = blk.relprop(cam, **kwargs)
|
350 |
+
|
351 |
+
# print("conservation 2", cam.sum())
|
352 |
+
# print("min", cam.min())
|
353 |
+
|
354 |
+
if method == "full":
|
355 |
+
(cam, _) = self.add.relprop(cam, **kwargs)
|
356 |
+
cam = cam[:, 1:]
|
357 |
+
cam = self.patch_embed.relprop(cam, **kwargs)
|
358 |
+
# sum on channels
|
359 |
+
cam = cam.sum(dim=1)
|
360 |
+
return cam
|
361 |
+
|
362 |
+
elif method == "rollout":
|
363 |
+
# cam rollout
|
364 |
+
attn_cams = []
|
365 |
+
for blk in self.blocks:
|
366 |
+
attn_heads = blk.attn.get_attn_cam().clamp(min=0)
|
367 |
+
avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
|
368 |
+
attn_cams.append(avg_heads)
|
369 |
+
cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
|
370 |
+
cam = cam[:, 0, 1:]
|
371 |
+
return cam
|
372 |
+
|
373 |
+
# our method, method name grad is legacy
|
374 |
+
elif method == "transformer_attribution" or method == "grad":
|
375 |
+
cams = []
|
376 |
+
for blk in self.blocks:
|
377 |
+
grad = blk.attn.get_attn_gradients()
|
378 |
+
cam = blk.attn.get_attn_cam()
|
379 |
+
cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
|
380 |
+
grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
|
381 |
+
cam = grad * cam
|
382 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
383 |
+
cams.append(cam.unsqueeze(0))
|
384 |
+
rollout = compute_rollout_attention(cams, start_layer=start_layer)
|
385 |
+
cam = rollout[:, 0, 1:]
|
386 |
+
return cam
|
387 |
+
|
388 |
+
elif method == "last_layer":
|
389 |
+
cam = self.blocks[-1].attn.get_attn_cam()
|
390 |
+
cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
|
391 |
+
if is_ablation:
|
392 |
+
grad = self.blocks[-1].attn.get_attn_gradients()
|
393 |
+
grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
|
394 |
+
cam = grad * cam
|
395 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
396 |
+
cam = cam[0, 1:]
|
397 |
+
return cam
|
398 |
+
|
399 |
+
elif method == "last_layer_attn":
|
400 |
+
cam = self.blocks[-1].attn.get_attn()
|
401 |
+
cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
|
402 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
403 |
+
cam = cam[0, 1:]
|
404 |
+
return cam
|
405 |
+
|
406 |
+
elif method == "second_layer":
|
407 |
+
cam = self.blocks[1].attn.get_attn_cam()
|
408 |
+
cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
|
409 |
+
if is_ablation:
|
410 |
+
grad = self.blocks[1].attn.get_attn_gradients()
|
411 |
+
grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
|
412 |
+
cam = grad * cam
|
413 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
414 |
+
cam = cam[0, 1:]
|
415 |
+
return cam
|
models/__init__.py
ADDED
File without changes
|
models/__pycache__/GraphTransformer.cpython-38.pyc
ADDED
Binary file (3.35 kB). View file
|
|
models/__pycache__/ViT.cpython-38.pyc
ADDED
Binary file (12.5 kB). View file
|
|
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (168 Bytes). View file
|
|
models/__pycache__/gcn.cpython-38.pyc
ADDED
Binary file (9.61 kB). View file
|
|
models/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (9.93 kB). View file
|
|
models/__pycache__/weight_init.cpython-38.pyc
ADDED
Binary file (1.72 kB). View file
|
|
models/gcn.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
torch.set_printoptions(precision=2,threshold=float('inf'))
|
10 |
+
|
11 |
+
class AGCNBlock(nn.Module):
|
12 |
+
def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0):
|
13 |
+
super(AGCNBlock,self).__init__()
|
14 |
+
if dropout > 0.001:
|
15 |
+
self.dropout_layer = nn.Dropout(p=dropout)
|
16 |
+
self.sort = 'sort'
|
17 |
+
self.model='agcn'
|
18 |
+
self.gcns=nn.ModuleList()
|
19 |
+
self.bn = 0
|
20 |
+
self.add_self = 1
|
21 |
+
self.normalize_embedding = 1
|
22 |
+
self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
|
23 |
+
self.pool = 'mean'
|
24 |
+
self.tau = 1.
|
25 |
+
self.lamda = 1.
|
26 |
+
|
27 |
+
for i in range(gcn_layer-1):
|
28 |
+
if i==gcn_layer-2 and (not 1):
|
29 |
+
self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,0))
|
30 |
+
else:
|
31 |
+
self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu))
|
32 |
+
|
33 |
+
if self.model=='diffpool':
|
34 |
+
self.pool_gcns=nn.ModuleList()
|
35 |
+
tmp=input_dim
|
36 |
+
self.diffpool_k=200
|
37 |
+
for i in range(3):
|
38 |
+
self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu))
|
39 |
+
tmp=200
|
40 |
+
|
41 |
+
self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1))
|
42 |
+
self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1))
|
43 |
+
torch.nn.init.normal_(self.w_a)
|
44 |
+
torch.nn.init.uniform_(self.w_b,-1,1)
|
45 |
+
|
46 |
+
self.pass_dim=hidden_dim
|
47 |
+
|
48 |
+
if self.pool=='mean':
|
49 |
+
self.pool=self.mean_pool
|
50 |
+
elif self.pool=='max':
|
51 |
+
self.pool=self.max_pool
|
52 |
+
elif self.pool=='sum':
|
53 |
+
self.pool=self.sum_pool
|
54 |
+
|
55 |
+
self.softmax='global'
|
56 |
+
if self.softmax=='gcn':
|
57 |
+
self.att_gcn=GCNBlock(2,1,0,0,dropout,relu)
|
58 |
+
self.khop=1
|
59 |
+
self.adj_norm='none'
|
60 |
+
|
61 |
+
self.filt_percent=0.25 #default 0.5
|
62 |
+
self.eps=1e-10
|
63 |
+
|
64 |
+
self.tau_config=1
|
65 |
+
if 1==-1.:
|
66 |
+
self.tau=nn.Parameter(torch.tensor(1),requires_grad=False)
|
67 |
+
elif 1==-2.:
|
68 |
+
self.tau_fc=nn.Linear(hidden_dim,1)
|
69 |
+
torch.nn.init.constant_(self.tau_fc.bias,1)
|
70 |
+
torch.nn.init.xavier_normal_(self.tau_fc.weight.t())
|
71 |
+
else:
|
72 |
+
self.tau=nn.Parameter(torch.tensor(self.tau))
|
73 |
+
self.lamda1=nn.Parameter(torch.tensor(self.lamda))
|
74 |
+
self.lamda2=nn.Parameter(torch.tensor(self.lamda))
|
75 |
+
|
76 |
+
self.att_norm=0
|
77 |
+
|
78 |
+
self.dnorm=0
|
79 |
+
self.dnorm_coe=1
|
80 |
+
|
81 |
+
self.att_out=0
|
82 |
+
self.single_att=0
|
83 |
+
|
84 |
+
|
85 |
+
def forward(self,X,adj,mask,is_print=False):
|
86 |
+
'''
|
87 |
+
input:
|
88 |
+
X: node input features , [batch,node_num,input_dim],dtype=float
|
89 |
+
adj: adj matrix, [batch,node_num,node_num], dtype=float
|
90 |
+
mask: mask for nodes, [batch,node_num]
|
91 |
+
outputs:
|
92 |
+
out:unormalized classification prob, [batch,hidden_dim]
|
93 |
+
H: batch of node hidden features, [batch,node_num,pass_dim]
|
94 |
+
new_adj: pooled new adj matrix, [batch, k_max, k_max]
|
95 |
+
new_mask: [batch, k_max]
|
96 |
+
'''
|
97 |
+
hidden=X
|
98 |
+
#adj = adj.float()
|
99 |
+
# print('input size:')
|
100 |
+
# print(hidden.shape)
|
101 |
+
|
102 |
+
is_print1=is_print2=is_print
|
103 |
+
if adj.shape[-1]>100:
|
104 |
+
is_print1=False
|
105 |
+
|
106 |
+
for gcn in self.gcns:
|
107 |
+
hidden=gcn(hidden,adj,mask)
|
108 |
+
# print('gcn:')
|
109 |
+
# print(hidden.shape)
|
110 |
+
# print('mask:')
|
111 |
+
# print(mask.unsqueeze(2).shape)
|
112 |
+
# print(mask.sum(dim=1))
|
113 |
+
|
114 |
+
hidden=mask.unsqueeze(2)*hidden
|
115 |
+
# print(hidden[0][0])
|
116 |
+
# print(hidden[0][-1])
|
117 |
+
|
118 |
+
if self.model=='unet':
|
119 |
+
att=torch.matmul(hidden,self.w_a).squeeze()
|
120 |
+
att=att/torch.sqrt((self.w_a.squeeze(2)**2).sum(dim=1,keepdim=True))
|
121 |
+
elif self.model=='agcn':
|
122 |
+
if self.softmax=='global' or self.softmax=='mix':
|
123 |
+
if False:
|
124 |
+
dgree_w = torch.sum(adj, dim=2) / torch.sum(adj, dim=2).max(1, keepdim=True)[0]
|
125 |
+
att_a=torch.matmul(hidden,self.w_a).squeeze()*dgree_w+(mask-1)*1e10
|
126 |
+
else:
|
127 |
+
att_a=torch.matmul(hidden,self.w_a).squeeze()+(mask-1)*1e10
|
128 |
+
# print(att_a[0][:10])
|
129 |
+
# print(att_a[0][-10:-1])
|
130 |
+
att_a_1=att_a=torch.nn.functional.softmax(att_a,dim=1)
|
131 |
+
# print(att_a[0][:10])
|
132 |
+
# print(att_a[0][-10:-1])
|
133 |
+
|
134 |
+
if self.dnorm:
|
135 |
+
scale=mask.sum(dim=1,keepdim=True)/self.dnorm_coe
|
136 |
+
att_a=scale*att_a
|
137 |
+
if self.softmax=='neibor' or self.softmax=='mix':
|
138 |
+
att_b=torch.matmul(hidden,self.w_b).squeeze()+(mask-1)*1e10
|
139 |
+
att_b_max,_=att_b.max(dim=1,keepdim=True)
|
140 |
+
if self.tau_config!=-2:
|
141 |
+
att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau))
|
142 |
+
else:
|
143 |
+
att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau_fc(self.pool(hidden,mask))))
|
144 |
+
denom=att_b.unsqueeze(2)
|
145 |
+
for _ in range(self.khop):
|
146 |
+
denom=torch.matmul(adj,denom)
|
147 |
+
denom=denom.squeeze()+self.eps
|
148 |
+
att_b=(att_b*torch.diagonal(adj,0,1,2))/denom
|
149 |
+
if self.dnorm:
|
150 |
+
if self.adj_norm=='diag':
|
151 |
+
diag_scale=mask/(torch.diagonal(adj,0,1,2)+self.eps)
|
152 |
+
elif self.adj_norm=='none':
|
153 |
+
diag_scale=adj.sum(dim=1)
|
154 |
+
att_b=att_b*diag_scale
|
155 |
+
att_b=att_b*mask
|
156 |
+
|
157 |
+
if self.softmax=='global':
|
158 |
+
att=att_a
|
159 |
+
elif self.softmax=='neibor' or self.softmax=='hardnei':
|
160 |
+
att=att_b
|
161 |
+
elif self.softmax=='mix':
|
162 |
+
att=att_a*torch.abs(self.lamda1)+att_b*torch.abs(self.lamda2)
|
163 |
+
# print('att:')
|
164 |
+
# print(att.shape)
|
165 |
+
Z=hidden
|
166 |
+
|
167 |
+
if self.model=='unet':
|
168 |
+
Z=torch.tanh(att.unsqueeze(2))*Z
|
169 |
+
elif self.model=='agcn':
|
170 |
+
if self.single_att:
|
171 |
+
Z=Z
|
172 |
+
else:
|
173 |
+
Z=att.unsqueeze(2)*Z
|
174 |
+
# print('Z shape')
|
175 |
+
# print(Z.shape)
|
176 |
+
k_max=int(math.ceil(self.filt_percent*adj.shape[-1]))
|
177 |
+
# print('k_max')
|
178 |
+
# print(k_max)
|
179 |
+
if self.model=='diffpool':
|
180 |
+
k_max=min(k_max,self.diffpool_k)
|
181 |
+
|
182 |
+
k_list=[int(math.ceil(self.filt_percent*x)) for x in mask.sum(dim=1).tolist()]
|
183 |
+
# print('k_list')
|
184 |
+
# print(k_list)
|
185 |
+
if self.model!='diffpool':
|
186 |
+
if self.sort=='sample':
|
187 |
+
att_samp = att * mask
|
188 |
+
att_samp = (att_samp/att_samp.sum(1)).detach().cpu().numpy()
|
189 |
+
top_index = ()
|
190 |
+
for i in range(att.size(0)):
|
191 |
+
top_index = (torch.LongTensor(np.random.choice(att_samp.size(1), k_max, att_samp[i])) ,)
|
192 |
+
top_index = torch.stack(top_index,1)
|
193 |
+
elif self.sort=='random_sample':
|
194 |
+
top_index = torch.LongTensor(att.size(0), k_max)*0
|
195 |
+
for i in range(att.size(0)):
|
196 |
+
top_index[i,0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]]
|
197 |
+
else: #sort
|
198 |
+
_,top_index=torch.topk(att,k_max,dim=1)
|
199 |
+
# print('top_index')
|
200 |
+
# print(top_index)
|
201 |
+
# print(len(top_index[0]))
|
202 |
+
new_mask=X.new_zeros(X.shape[0],k_max)
|
203 |
+
# print('new_mask')
|
204 |
+
# print(new_mask.shape)
|
205 |
+
visualize_tools=None
|
206 |
+
if self.model=='unet':
|
207 |
+
for i,k in enumerate(k_list):
|
208 |
+
for j in range(int(k),k_max):
|
209 |
+
top_index[i][j]=adj.shape[-1]-1
|
210 |
+
new_mask[i][j]=-1.
|
211 |
+
new_mask=new_mask+1
|
212 |
+
top_index,_=torch.sort(top_index,dim=1)
|
213 |
+
assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
|
214 |
+
for i,x in enumerate(top_index):
|
215 |
+
assign_m[i]=torch.index_select(adj[i],0,x)
|
216 |
+
new_adj=X.new_zeros(X.shape[0],k_max,k_max)
|
217 |
+
H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1])
|
218 |
+
for i,x in enumerate(top_index):
|
219 |
+
new_adj[i]=torch.index_select(assign_m[i],1,x)
|
220 |
+
H[i]=torch.index_select(Z[i],0,x)
|
221 |
+
|
222 |
+
elif self.model=='agcn':
|
223 |
+
assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1])
|
224 |
+
# print('assign_m.shape')
|
225 |
+
# print(assign_m.shape)
|
226 |
+
for i,k in enumerate(k_list):
|
227 |
+
#print('top_index[i][j]')
|
228 |
+
for j in range(int(k)):
|
229 |
+
#print(str(top_index[i][j].item())+' ', end='')
|
230 |
+
assign_m[i][j]=adj[i][top_index[i][j]]
|
231 |
+
#print(assign_m[i][j])
|
232 |
+
new_mask[i][j]=1.
|
233 |
+
|
234 |
+
assign_m=assign_m/(assign_m.sum(dim=1,keepdim=True)+self.eps)
|
235 |
+
H=torch.matmul(assign_m,Z)
|
236 |
+
# print('H')
|
237 |
+
# print(H.shape)
|
238 |
+
new_adj=torch.matmul(torch.matmul(assign_m,adj),torch.transpose(assign_m,1,2))
|
239 |
+
# print(torch.matmul(assign_m,adj).shape)
|
240 |
+
# print('new_adj:')
|
241 |
+
# print(new_adj.shape)
|
242 |
+
|
243 |
+
elif self.model=='diffpool':
|
244 |
+
hidden1=X
|
245 |
+
for gcn in self.pool_gcns:
|
246 |
+
hidden1=gcn(hidden1,adj,mask)
|
247 |
+
assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.)
|
248 |
+
for i,x in enumerate(hidden1):
|
249 |
+
k=min(k_list[i],k_max)
|
250 |
+
assign_m[i,:,0:k]=hidden1[i,:,0:k]
|
251 |
+
for j in range(int(k)):
|
252 |
+
new_mask[i][j]=1.
|
253 |
+
|
254 |
+
assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2)
|
255 |
+
assign_m_t=torch.transpose(assign_m,1,2)
|
256 |
+
new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m)
|
257 |
+
H=torch.matmul(assign_m_t,Z)
|
258 |
+
# print('pool')
|
259 |
+
if self.att_out and self.model=='agcn':
|
260 |
+
if self.softmax=='global':
|
261 |
+
out=self.pool(att_a_1.unsqueeze(2)*hidden,mask)
|
262 |
+
elif self.softmax=='neibor':
|
263 |
+
att_b_sum=att_b.sum(dim=1,keepdim=True)
|
264 |
+
out=self.pool((att_b/(att_b_sum+self.eps)).unsqueeze(2)*hidden,mask)
|
265 |
+
else:
|
266 |
+
# print('hidden.shape')
|
267 |
+
# print(hidden.shape)
|
268 |
+
out=self.pool(hidden,mask)
|
269 |
+
# print('out shape')
|
270 |
+
# print(out.shape)
|
271 |
+
|
272 |
+
if self.adj_norm=='tanh' or self.adj_norm=='mix':
|
273 |
+
new_adj=torch.tanh(new_adj)
|
274 |
+
elif self.adj_norm=='diag' or self.adj_norm=='mix':
|
275 |
+
diag_elem=torch.pow(new_adj.sum(dim=2)+self.eps,-0.5)
|
276 |
+
diag=new_adj.new_zeros(new_adj.shape)
|
277 |
+
for i,x in enumerate(diag_elem):
|
278 |
+
diag[i]=torch.diagflat(x)
|
279 |
+
new_adj=torch.matmul(torch.matmul(diag,new_adj),diag)
|
280 |
+
|
281 |
+
visualize_tools=[]
|
282 |
+
'''
|
283 |
+
if (not self.training) and is_print1:
|
284 |
+
print('**********************************')
|
285 |
+
print('node_feat:',X.type(),X.shape)
|
286 |
+
print(X)
|
287 |
+
if self.model!='diffpool':
|
288 |
+
print('**********************************')
|
289 |
+
print('att:',att.type(),att.shape)
|
290 |
+
print(att)
|
291 |
+
print('**********************************')
|
292 |
+
print('top_index:',top_index.type(),top_index.shape)
|
293 |
+
print(top_index)
|
294 |
+
print('**********************************')
|
295 |
+
print('adj:',adj.type(),adj.shape)
|
296 |
+
print(adj)
|
297 |
+
print('**********************************')
|
298 |
+
print('assign_m:',assign_m.type(),assign_m.shape)
|
299 |
+
print(assign_m)
|
300 |
+
print('**********************************')
|
301 |
+
print('new_adj:',new_adj.type(),new_adj.shape)
|
302 |
+
print(new_adj)
|
303 |
+
print('**********************************')
|
304 |
+
print('new_mask:',new_mask.type(),new_mask.shape)
|
305 |
+
print(new_mask)
|
306 |
+
'''
|
307 |
+
#visualization
|
308 |
+
from os import path
|
309 |
+
if not path.exists('att_1.pt'):
|
310 |
+
torch.save(att[0], 'att_1.pt')
|
311 |
+
torch.save(top_index[0], 'att_ind1.pt')
|
312 |
+
elif not path.exists('att_2.pt'):
|
313 |
+
torch.save(att[0], 'att_2.pt')
|
314 |
+
torch.save(top_index[0], 'att_ind2.pt')
|
315 |
+
else:
|
316 |
+
torch.save(att[0], 'att_3.pt')
|
317 |
+
torch.save(top_index[0], 'att_ind3.pt')
|
318 |
+
|
319 |
+
if (not self.training) and is_print2:
|
320 |
+
if self.model!='diffpool':
|
321 |
+
visualize_tools.append(att[0])
|
322 |
+
visualize_tools.append(top_index[0])
|
323 |
+
visualize_tools.append(new_adj[0])
|
324 |
+
visualize_tools.append(new_mask.sum())
|
325 |
+
# print('**********************************')
|
326 |
+
return out,H,new_adj,new_mask,visualize_tools
|
327 |
+
|
328 |
+
def mean_pool(self,x,mask):
|
329 |
+
return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True))
|
330 |
+
|
331 |
+
def sum_pool(self,x,mask):
|
332 |
+
return x.sum(dim=1)
|
333 |
+
|
334 |
+
@staticmethod
|
335 |
+
def max_pool(x,mask):
|
336 |
+
#output: [batch,x.shape[2]]
|
337 |
+
m=(mask-1)*1e10
|
338 |
+
r,_=(x+m.unsqueeze(2)).max(dim=1)
|
339 |
+
return r
|
340 |
+
# GCN basic operation
|
341 |
+
class GCNBlock(nn.Module):
|
342 |
+
def __init__(self, input_dim, output_dim, bn=0,add_self=0, normalize_embedding=0,
|
343 |
+
dropout=0.0,relu=0, bias=True):
|
344 |
+
super(GCNBlock,self).__init__()
|
345 |
+
self.add_self = add_self
|
346 |
+
self.dropout = dropout
|
347 |
+
self.relu=relu
|
348 |
+
self.bn=bn
|
349 |
+
if dropout > 0.001:
|
350 |
+
self.dropout_layer = nn.Dropout(p=dropout)
|
351 |
+
if self.bn:
|
352 |
+
self.bn_layer = torch.nn.BatchNorm1d(output_dim)
|
353 |
+
|
354 |
+
self.normalize_embedding = normalize_embedding
|
355 |
+
self.input_dim = input_dim
|
356 |
+
self.output_dim = output_dim
|
357 |
+
|
358 |
+
self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
|
359 |
+
torch.nn.init.xavier_normal_(self.weight)
|
360 |
+
if bias:
|
361 |
+
self.bias = nn.Parameter(torch.zeros(output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') )
|
362 |
+
else:
|
363 |
+
self.bias = None
|
364 |
+
|
365 |
+
def forward(self, x, adj, mask):
|
366 |
+
y = torch.matmul(adj, x)
|
367 |
+
if self.add_self:
|
368 |
+
y += x
|
369 |
+
y = torch.matmul(y,self.weight)
|
370 |
+
if self.bias is not None:
|
371 |
+
y = y + self.bias
|
372 |
+
if self.normalize_embedding:
|
373 |
+
y = F.normalize(y, p=2, dim=2)
|
374 |
+
if self.bn:
|
375 |
+
index=mask.sum(dim=1).long().tolist()
|
376 |
+
bn_tensor_bf=mask.new_zeros((sum(index),y.shape[2]))
|
377 |
+
bn_tensor_af=mask.new_zeros(*y.shape)
|
378 |
+
start_index=[]
|
379 |
+
ssum=0
|
380 |
+
for i in range(x.shape[0]):
|
381 |
+
start_index.append(ssum)
|
382 |
+
ssum+=index[i]
|
383 |
+
start_index.append(ssum)
|
384 |
+
for i in range(x.shape[0]):
|
385 |
+
bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]]
|
386 |
+
bn_tensor_bf=self.bn_layer(bn_tensor_bf)
|
387 |
+
for i in range(x.shape[0]):
|
388 |
+
bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]]
|
389 |
+
y=bn_tensor_af
|
390 |
+
if self.dropout > 0.001:
|
391 |
+
y = self.dropout_layer(y)
|
392 |
+
if self.relu=='relu':
|
393 |
+
y=torch.nn.functional.relu(y)
|
394 |
+
print('hahah')
|
395 |
+
elif self.relu=='lrelu':
|
396 |
+
y=torch.nn.functional.leaky_relu(y,0.1)
|
397 |
+
return y
|
398 |
+
|
399 |
+
#experimental function, untested
|
400 |
+
class masked_batchnorm(nn.Module):
|
401 |
+
def __init__(self,feat_dim,epsilon=1e-10):
|
402 |
+
super().__init__()
|
403 |
+
self.alpha=nn.Parameter(torch.ones(feat_dim))
|
404 |
+
self.beta=nn.Parameter(torch.zeros(feat_dim))
|
405 |
+
self.eps=epsilon
|
406 |
+
|
407 |
+
def forward(self,x,mask):
|
408 |
+
'''
|
409 |
+
x: node feat, [batch,node_num,feat_dim]
|
410 |
+
mask: [batch,node_num]
|
411 |
+
'''
|
412 |
+
mask1 = mask.unsqueeze(2)
|
413 |
+
mask_sum = mask.sum()
|
414 |
+
mean = x.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
|
415 |
+
temp = (x - mean)**2
|
416 |
+
temp = temp*mask1
|
417 |
+
var = temp.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum)
|
418 |
+
rstd = torch.rsqrt(var+self.eps)
|
419 |
+
x=(x-mean)*rstd
|
420 |
+
return ((x*self.alpha) + self.beta)*mask1
|
models/layers.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
|
6 |
+
'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
|
7 |
+
'LayerNorm', 'AddEye']
|
8 |
+
|
9 |
+
|
10 |
+
def safe_divide(a, b):
|
11 |
+
den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
|
12 |
+
den = den + den.eq(0).type(den.type()) * 1e-9
|
13 |
+
return a / den * b.ne(0).type(b.type())
|
14 |
+
|
15 |
+
|
16 |
+
def forward_hook(self, input, output):
|
17 |
+
if type(input[0]) in (list, tuple):
|
18 |
+
self.X = []
|
19 |
+
for i in input[0]:
|
20 |
+
x = i.detach()
|
21 |
+
x.requires_grad = True
|
22 |
+
self.X.append(x)
|
23 |
+
else:
|
24 |
+
self.X = input[0].detach()
|
25 |
+
self.X.requires_grad = True
|
26 |
+
|
27 |
+
self.Y = output
|
28 |
+
|
29 |
+
|
30 |
+
def backward_hook(self, grad_input, grad_output):
|
31 |
+
self.grad_input = grad_input
|
32 |
+
self.grad_output = grad_output
|
33 |
+
|
34 |
+
|
35 |
+
class RelProp(nn.Module):
|
36 |
+
def __init__(self):
|
37 |
+
super(RelProp, self).__init__()
|
38 |
+
# if not self.training:
|
39 |
+
self.register_forward_hook(forward_hook)
|
40 |
+
|
41 |
+
def gradprop(self, Z, X, S):
|
42 |
+
C = torch.autograd.grad(Z, X, S, retain_graph=True)
|
43 |
+
return C
|
44 |
+
|
45 |
+
def relprop(self, R, alpha):
|
46 |
+
return R
|
47 |
+
|
48 |
+
class RelPropSimple(RelProp):
|
49 |
+
def relprop(self, R, alpha):
|
50 |
+
Z = self.forward(self.X)
|
51 |
+
S = safe_divide(R, Z)
|
52 |
+
C = self.gradprop(Z, self.X, S)
|
53 |
+
|
54 |
+
if torch.is_tensor(self.X) == False:
|
55 |
+
outputs = []
|
56 |
+
outputs.append(self.X[0] * C[0])
|
57 |
+
outputs.append(self.X[1] * C[1])
|
58 |
+
else:
|
59 |
+
outputs = self.X * (C[0])
|
60 |
+
return outputs
|
61 |
+
|
62 |
+
class AddEye(RelPropSimple):
|
63 |
+
# input of shape B, C, seq_len, seq_len
|
64 |
+
def forward(self, input):
|
65 |
+
return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
|
66 |
+
|
67 |
+
class ReLU(nn.ReLU, RelProp):
|
68 |
+
pass
|
69 |
+
|
70 |
+
class GELU(nn.GELU, RelProp):
|
71 |
+
pass
|
72 |
+
|
73 |
+
class Softmax(nn.Softmax, RelProp):
|
74 |
+
pass
|
75 |
+
|
76 |
+
class LayerNorm(nn.LayerNorm, RelProp):
|
77 |
+
pass
|
78 |
+
|
79 |
+
class Dropout(nn.Dropout, RelProp):
|
80 |
+
pass
|
81 |
+
|
82 |
+
|
83 |
+
class MaxPool2d(nn.MaxPool2d, RelPropSimple):
|
84 |
+
pass
|
85 |
+
|
86 |
+
class LayerNorm(nn.LayerNorm, RelProp):
|
87 |
+
pass
|
88 |
+
|
89 |
+
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
|
90 |
+
pass
|
91 |
+
|
92 |
+
|
93 |
+
class AvgPool2d(nn.AvgPool2d, RelPropSimple):
|
94 |
+
pass
|
95 |
+
|
96 |
+
|
97 |
+
class Add(RelPropSimple):
|
98 |
+
def forward(self, inputs):
|
99 |
+
return torch.add(*inputs)
|
100 |
+
|
101 |
+
def relprop(self, R, alpha):
|
102 |
+
Z = self.forward(self.X)
|
103 |
+
S = safe_divide(R, Z)
|
104 |
+
C = self.gradprop(Z, self.X, S)
|
105 |
+
|
106 |
+
a = self.X[0] * C[0]
|
107 |
+
b = self.X[1] * C[1]
|
108 |
+
|
109 |
+
a_sum = a.sum()
|
110 |
+
b_sum = b.sum()
|
111 |
+
|
112 |
+
a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
113 |
+
b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
114 |
+
|
115 |
+
a = a * safe_divide(a_fact, a.sum())
|
116 |
+
b = b * safe_divide(b_fact, b.sum())
|
117 |
+
|
118 |
+
outputs = [a, b]
|
119 |
+
|
120 |
+
return outputs
|
121 |
+
|
122 |
+
class einsum(RelPropSimple):
|
123 |
+
def __init__(self, equation):
|
124 |
+
super().__init__()
|
125 |
+
self.equation = equation
|
126 |
+
def forward(self, *operands):
|
127 |
+
return torch.einsum(self.equation, *operands)
|
128 |
+
|
129 |
+
class IndexSelect(RelProp):
|
130 |
+
def forward(self, inputs, dim, indices):
|
131 |
+
self.__setattr__('dim', dim)
|
132 |
+
self.__setattr__('indices', indices)
|
133 |
+
|
134 |
+
return torch.index_select(inputs, dim, indices)
|
135 |
+
|
136 |
+
def relprop(self, R, alpha):
|
137 |
+
Z = self.forward(self.X, self.dim, self.indices)
|
138 |
+
S = safe_divide(R, Z)
|
139 |
+
C = self.gradprop(Z, self.X, S)
|
140 |
+
|
141 |
+
if torch.is_tensor(self.X) == False:
|
142 |
+
outputs = []
|
143 |
+
outputs.append(self.X[0] * C[0])
|
144 |
+
outputs.append(self.X[1] * C[1])
|
145 |
+
else:
|
146 |
+
outputs = self.X * (C[0])
|
147 |
+
return outputs
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
class Clone(RelProp):
|
152 |
+
def forward(self, input, num):
|
153 |
+
self.__setattr__('num', num)
|
154 |
+
outputs = []
|
155 |
+
for _ in range(num):
|
156 |
+
outputs.append(input)
|
157 |
+
|
158 |
+
return outputs
|
159 |
+
|
160 |
+
def relprop(self, R, alpha):
|
161 |
+
Z = []
|
162 |
+
for _ in range(self.num):
|
163 |
+
Z.append(self.X)
|
164 |
+
S = [safe_divide(r, z) for r, z in zip(R, Z)]
|
165 |
+
C = self.gradprop(Z, self.X, S)[0]
|
166 |
+
|
167 |
+
R = self.X * C
|
168 |
+
|
169 |
+
return R
|
170 |
+
|
171 |
+
class Cat(RelProp):
|
172 |
+
def forward(self, inputs, dim):
|
173 |
+
self.__setattr__('dim', dim)
|
174 |
+
return torch.cat(inputs, dim)
|
175 |
+
|
176 |
+
def relprop(self, R, alpha):
|
177 |
+
Z = self.forward(self.X, self.dim)
|
178 |
+
S = safe_divide(R, Z)
|
179 |
+
C = self.gradprop(Z, self.X, S)
|
180 |
+
|
181 |
+
outputs = []
|
182 |
+
for x, c in zip(self.X, C):
|
183 |
+
outputs.append(x * c)
|
184 |
+
|
185 |
+
return outputs
|
186 |
+
|
187 |
+
|
188 |
+
class Sequential(nn.Sequential):
|
189 |
+
def relprop(self, R, alpha):
|
190 |
+
for m in reversed(self._modules.values()):
|
191 |
+
R = m.relprop(R, alpha)
|
192 |
+
return R
|
193 |
+
|
194 |
+
class BatchNorm2d(nn.BatchNorm2d, RelProp):
|
195 |
+
def relprop(self, R, alpha):
|
196 |
+
X = self.X
|
197 |
+
beta = 1 - alpha
|
198 |
+
weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
|
199 |
+
(self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
|
200 |
+
Z = X * weight + 1e-9
|
201 |
+
S = R / Z
|
202 |
+
Ca = S * weight
|
203 |
+
R = self.X * (Ca)
|
204 |
+
return R
|
205 |
+
|
206 |
+
|
207 |
+
class Linear(nn.Linear, RelProp):
|
208 |
+
def relprop(self, R, alpha):
|
209 |
+
beta = alpha - 1
|
210 |
+
pw = torch.clamp(self.weight, min=0)
|
211 |
+
nw = torch.clamp(self.weight, max=0)
|
212 |
+
px = torch.clamp(self.X, min=0)
|
213 |
+
nx = torch.clamp(self.X, max=0)
|
214 |
+
|
215 |
+
def f(w1, w2, x1, x2):
|
216 |
+
Z1 = F.linear(x1, w1)
|
217 |
+
Z2 = F.linear(x2, w2)
|
218 |
+
S1 = safe_divide(R, Z1 + Z2)
|
219 |
+
S2 = safe_divide(R, Z1 + Z2)
|
220 |
+
C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
|
221 |
+
C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
|
222 |
+
|
223 |
+
return C1 + C2
|
224 |
+
|
225 |
+
activator_relevances = f(pw, nw, px, nx)
|
226 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
227 |
+
|
228 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
229 |
+
|
230 |
+
return R
|
231 |
+
|
232 |
+
|
233 |
+
class Conv2d(nn.Conv2d, RelProp):
|
234 |
+
def gradprop2(self, DY, weight):
|
235 |
+
Z = self.forward(self.X)
|
236 |
+
|
237 |
+
output_padding = self.X.size()[2] - (
|
238 |
+
(Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
|
239 |
+
|
240 |
+
return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
|
241 |
+
|
242 |
+
def relprop(self, R, alpha):
|
243 |
+
if self.X.shape[1] == 3:
|
244 |
+
pw = torch.clamp(self.weight, min=0)
|
245 |
+
nw = torch.clamp(self.weight, max=0)
|
246 |
+
X = self.X
|
247 |
+
L = self.X * 0 + \
|
248 |
+
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
249 |
+
keepdim=True)[0]
|
250 |
+
H = self.X * 0 + \
|
251 |
+
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
252 |
+
keepdim=True)[0]
|
253 |
+
Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
|
254 |
+
torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
|
255 |
+
torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
|
256 |
+
|
257 |
+
S = R / Za
|
258 |
+
C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
|
259 |
+
R = C
|
260 |
+
else:
|
261 |
+
beta = alpha - 1
|
262 |
+
pw = torch.clamp(self.weight, min=0)
|
263 |
+
nw = torch.clamp(self.weight, max=0)
|
264 |
+
px = torch.clamp(self.X, min=0)
|
265 |
+
nx = torch.clamp(self.X, max=0)
|
266 |
+
|
267 |
+
def f(w1, w2, x1, x2):
|
268 |
+
Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
|
269 |
+
Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
|
270 |
+
S1 = safe_divide(R, Z1)
|
271 |
+
S2 = safe_divide(R, Z2)
|
272 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
273 |
+
C2 = x2 * self.gradprop(Z2, x2, S2)[0]
|
274 |
+
return C1 + C2
|
275 |
+
|
276 |
+
activator_relevances = f(pw, nw, px, nx)
|
277 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
278 |
+
|
279 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
280 |
+
return R
|
models/weight_init.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:UTF-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.init as init
|
7 |
+
|
8 |
+
|
9 |
+
def weight_init(m):
|
10 |
+
'''
|
11 |
+
Usage:
|
12 |
+
model = Model()
|
13 |
+
model.apply(weight_init)
|
14 |
+
'''
|
15 |
+
if isinstance(m, nn.Conv1d):
|
16 |
+
init.normal_(m.weight.data)
|
17 |
+
if m.bias is not None:
|
18 |
+
init.normal_(m.bias.data)
|
19 |
+
elif isinstance(m, nn.Conv2d):
|
20 |
+
init.xavier_normal_(m.weight.data)
|
21 |
+
if m.bias is not None:
|
22 |
+
init.normal_(m.bias.data)
|
23 |
+
elif isinstance(m, nn.Conv3d):
|
24 |
+
init.xavier_normal_(m.weight.data)
|
25 |
+
if m.bias is not None:
|
26 |
+
init.normal_(m.bias.data)
|
27 |
+
elif isinstance(m, nn.ConvTranspose1d):
|
28 |
+
init.normal_(m.weight.data)
|
29 |
+
if m.bias is not None:
|
30 |
+
init.normal_(m.bias.data)
|
31 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
32 |
+
init.xavier_normal_(m.weight.data)
|
33 |
+
if m.bias is not None:
|
34 |
+
init.normal_(m.bias.data)
|
35 |
+
elif isinstance(m, nn.ConvTranspose3d):
|
36 |
+
init.xavier_normal_(m.weight.data)
|
37 |
+
if m.bias is not None:
|
38 |
+
init.normal_(m.bias.data)
|
39 |
+
elif isinstance(m, nn.BatchNorm1d):
|
40 |
+
init.normal_(m.weight.data, mean=1, std=0.02)
|
41 |
+
init.constant_(m.bias.data, 0)
|
42 |
+
elif isinstance(m, nn.BatchNorm2d):
|
43 |
+
init.normal_(m.weight.data, mean=1, std=0.02)
|
44 |
+
init.constant_(m.bias.data, 0)
|
45 |
+
elif isinstance(m, nn.BatchNorm3d):
|
46 |
+
init.normal_(m.weight.data, mean=1, std=0.02)
|
47 |
+
init.constant_(m.bias.data, 0)
|
48 |
+
elif isinstance(m, nn.Linear):
|
49 |
+
init.xavier_normal_(m.weight.data)
|
50 |
+
init.normal_(m.bias.data)
|
51 |
+
elif isinstance(m, nn.LSTM):
|
52 |
+
for param in m.parameters():
|
53 |
+
if len(param.shape) >= 2:
|
54 |
+
init.orthogonal_(param.data)
|
55 |
+
else:
|
56 |
+
init.normal_(param.data)
|
57 |
+
elif isinstance(m, nn.LSTMCell):
|
58 |
+
for param in m.parameters():
|
59 |
+
if len(param.shape) >= 2:
|
60 |
+
init.orthogonal_(param.data)
|
61 |
+
else:
|
62 |
+
init.normal_(param.data)
|
63 |
+
elif isinstance(m, nn.GRU):
|
64 |
+
for param in m.parameters():
|
65 |
+
if len(param.shape) >= 2:
|
66 |
+
init.orthogonal_(param.data)
|
67 |
+
else:
|
68 |
+
init.normal_(param.data)
|
69 |
+
elif isinstance(m, nn.GRUCell):
|
70 |
+
for param in m.parameters():
|
71 |
+
if len(param.shape) >= 2:
|
72 |
+
init.orthogonal_(param.data)
|
73 |
+
else:
|
74 |
+
init.normal_(param.data)
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
pass
|
option.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# Created by: YI ZHENG
|
3 |
+
# Email: [email protected]
|
4 |
+
# Copyright (c) 2020
|
5 |
+
###########################################################################
|
6 |
+
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
|
11 |
+
class Options():
|
12 |
+
def __init__(self):
|
13 |
+
parser = argparse.ArgumentParser(description='PyTorch Classification')
|
14 |
+
parser.add_argument('--data_path', type=str, help='path to dataset where images store')
|
15 |
+
parser.add_argument('--train_set', type=str, help='train')
|
16 |
+
parser.add_argument('--val_set', type=str, help='validation')
|
17 |
+
parser.add_argument('--model_path', type=str, help='path to trained model')
|
18 |
+
parser.add_argument('--log_path', type=str, help='path to log files')
|
19 |
+
parser.add_argument('--task_name', type=str, help='task name for naming saved model files and log files')
|
20 |
+
parser.add_argument('--train', action='store_true', default=False, help='train only')
|
21 |
+
parser.add_argument('--test', action='store_true', default=False, help='test only')
|
22 |
+
parser.add_argument('--batch_size', type=int, default=6, help='batch size for origin global image (without downsampling)')
|
23 |
+
parser.add_argument('--log_interval_local', type=int, default=10, help='classification classes')
|
24 |
+
parser.add_argument('--resume', type=str, default="", help='path for model')
|
25 |
+
parser.add_argument('--graphcam', action='store_true', default=False, help='GraphCAM')
|
26 |
+
parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on')
|
27 |
+
|
28 |
+
|
29 |
+
# the parser
|
30 |
+
self.parser = parser
|
31 |
+
|
32 |
+
def parse(self):
|
33 |
+
args = self.parser.parse_args()
|
34 |
+
# default settings for epochs and lr
|
35 |
+
|
36 |
+
args.num_epochs = 120
|
37 |
+
args.lr = 1e-3
|
38 |
+
|
39 |
+
if args.test:
|
40 |
+
args.num_epochs = 1
|
41 |
+
return args
|