shengz commited on
Commit
1f05155
1 Parent(s): 27005c2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +184 -1
README.md CHANGED
@@ -46,7 +46,190 @@ BiomedCLIP establishes new state of the art in a wide range of standard datasets
46
 
47
  ## Model Use
48
 
49
- ### How to use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  Please refer to this [example notebook](https://aka.ms/biomedclip-example-notebook).
52
 
 
46
 
47
  ## Model Use
48
 
49
+ ### 1. Environment
50
+
51
+ ```bash
52
+ conda create -n biomedclip python=3.10 -y
53
+ conda activate biomedclip
54
+ pip install open_clip_torch==2.23.0 transformers==4.35.2 matplotlib
55
+ ```
56
+
57
+ ### 2.1 Load from HF hub
58
+
59
+ ```python
60
+ import torch
61
+ from urllib.request import urlopen
62
+ from PIL import Image
63
+ from open_clip import create_model_from_pretrained, get_tokenizer
64
+
65
+ # Load the model and config files from the Hugging Face Hub
66
+ model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
67
+ tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
68
+
69
+
70
+ # Zero-shot image classification
71
+ template = 'this is a photo of '
72
+ labels = [
73
+ 'adenocarcinoma histopathology',
74
+ 'brain MRI',
75
+ 'covid line chart',
76
+ 'squamous cell carcinoma histopathology',
77
+ 'immunohistochemistry histopathology',
78
+ 'bone X-ray',
79
+ 'chest X-ray',
80
+ 'pie chart',
81
+ 'hematoxylin and eosin histopathology'
82
+ ]
83
+
84
+ dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'
85
+ test_imgs = [
86
+ 'squamous_cell_carcinoma_histopathology.jpeg',
87
+ 'H_and_E_histopathology.jpg',
88
+ 'bone_X-ray.jpg',
89
+ 'adenocarcinoma_histopathology.jpg',
90
+ 'covid_line_chart.png',
91
+ 'IHC_histopathology.jpg',
92
+ 'chest_X-ray.jpg',
93
+ 'brain_MRI.jpg',
94
+ 'pie_chart.png'
95
+ ]
96
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
97
+ model.to(device)
98
+ model.eval()
99
+
100
+ context_length = 256
101
+
102
+ images = torch.stack([preprocess(Image.open(urlopen(dataset_url + img))) for img in test_imgs]).to(device)
103
+ texts = tokenizer([template + l for l in labels], context_length=context_length).to(device)
104
+ with torch.no_grad():
105
+ image_features, text_features, logit_scale = model(images, texts)
106
+
107
+ logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
108
+ sorted_indices = torch.argsort(logits, dim=-1, descending=True)
109
+
110
+ logits = logits.cpu().numpy()
111
+ sorted_indices = sorted_indices.cpu().numpy()
112
+
113
+ top_k = -1
114
+
115
+ for i, img in enumerate(test_imgs):
116
+ pred = labels[sorted_indices[i][0]]
117
+
118
+ top_k = len(labels) if top_k == -1 else top_k
119
+ print(img.split('/')[-1] + ':')
120
+ for j in range(top_k):
121
+ jth_index = sorted_indices[i][j]
122
+ print(f'{labels[jth_index]}: {logits[i][jth_index]}')
123
+ print('\n')
124
+ ```
125
+
126
+ ### 2.2 Load from local files
127
+
128
+ ```python
129
+ import json
130
+
131
+ from urllib.request import urlopen
132
+ from PIL import Image
133
+ import torch
134
+ from huggingface_hub import hf_hub_download
135
+ from open_clip import create_model_and_transforms, get_tokenizer
136
+ from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS
137
+
138
+
139
+ # Download the model and config files
140
+ hf_hub_download(
141
+ repo_id="microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
142
+ filename="open_clip_pytorch_model.bin",
143
+ local_dir="checkpoints"
144
+ )
145
+ hf_hub_download(
146
+ repo_id="microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
147
+ filename="open_clip_config.json",
148
+ local_dir="checkpoints"
149
+ )
150
+
151
+
152
+ # Load the model and config files
153
+ model_name = "biomedclip_local"
154
+
155
+ with open("checkpoints/open_clip_config.json", "r") as f:
156
+ config = json.load(f)
157
+ model_cfg = config["model_cfg"]
158
+ preprocess_cfg = config["preprocess_cfg"]
159
+
160
+
161
+ if (not model_name.startswith(HF_HUB_PREFIX)
162
+ and model_name not in _MODEL_CONFIGS
163
+ and config is not None):
164
+ _MODEL_CONFIGS[model_name] = model_cfg
165
+
166
+ tokenizer = get_tokenizer(model_name)
167
+
168
+ model, _, preprocess = create_model_and_transforms(
169
+ model_name=model_name,
170
+ pretrained="checkpoints/open_clip_pytorch_model.bin",
171
+ **{f"image_{k}": v for k, v in preprocess_cfg.items()},
172
+ )
173
+
174
+
175
+ # Zero-shot image classification
176
+ template = 'this is a photo of '
177
+ labels = [
178
+ 'adenocarcinoma histopathology',
179
+ 'brain MRI',
180
+ 'covid line chart',
181
+ 'squamous cell carcinoma histopathology',
182
+ 'immunohistochemistry histopathology',
183
+ 'bone X-ray',
184
+ 'chest X-ray',
185
+ 'pie chart',
186
+ 'hematoxylin and eosin histopathology'
187
+ ]
188
+
189
+ dataset_url = 'https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/'
190
+ test_imgs = [
191
+ 'squamous_cell_carcinoma_histopathology.jpeg',
192
+ 'H_and_E_histopathology.jpg',
193
+ 'bone_X-ray.jpg',
194
+ 'adenocarcinoma_histopathology.jpg',
195
+ 'covid_line_chart.png',
196
+ 'IHC_histopathology.jpg',
197
+ 'chest_X-ray.jpg',
198
+ 'brain_MRI.jpg',
199
+ 'pie_chart.png'
200
+ ]
201
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
202
+ model.to(device)
203
+ model.eval()
204
+
205
+ context_length = 256
206
+
207
+ images = torch.stack([preprocess(Image.open(urlopen(dataset_url + img))) for img in test_imgs]).to(device)
208
+ texts = tokenizer([template + l for l in labels], context_length=context_length).to(device)
209
+ with torch.no_grad():
210
+ image_features, text_features, logit_scale = model(images, texts)
211
+
212
+ logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
213
+ sorted_indices = torch.argsort(logits, dim=-1, descending=True)
214
+
215
+ logits = logits.cpu().numpy()
216
+ sorted_indices = sorted_indices.cpu().numpy()
217
+
218
+ top_k = -1
219
+
220
+ for i, img in enumerate(test_imgs):
221
+ pred = labels[sorted_indices[i][0]]
222
+
223
+ top_k = len(labels) if top_k == -1 else top_k
224
+ print(img.split('/')[-1] + ':')
225
+ for j in range(top_k):
226
+ jth_index = sorted_indices[i][j]
227
+ print(f'{labels[jth_index]}: {logits[i][jth_index]}')
228
+ print('\n')
229
+
230
+ ```
231
+
232
+ ### Use in Jupyter Notebook
233
 
234
  Please refer to this [example notebook](https://aka.ms/biomedclip-example-notebook).
235