Spaces:
Sleeping
Sleeping
Ben Feuer
commited on
Commit
•
ec3c973
1
Parent(s):
1e0e251
BioTrove Demo
Browse files- .gitattributes +4 -0
- .gitignore +2 -0
- README.md +5 -8
- app.py +303 -0
- components/metadata.csv +3 -0
- components/metadata_readme.md +11 -0
- components/query.py +116 -0
- components/sync_samples_to_s3.bash +34 -0
- embed_texts.sh +12 -0
- examples/Actinostola-abyssorum.png +3 -0
- examples/Amanita-muscaria.jpeg +3 -0
- examples/Carnegiea-gigantea.png +3 -0
- examples/Felis-catus.jpeg +3 -0
- examples/Onoclea-hintonii.jpg +0 -0
- examples/Onoclea-sensibilis.jpg +0 -0
- examples/Phoca-vitulina.png +3 -0
- examples/Sarcoscypha-coccinea.jpeg +3 -0
- examples/Ursus-arctos.jpeg +3 -0
- examples/coral-snake.jpeg +3 -0
- examples/milk-snake.png +3 -0
- lib.py +170 -0
- make_txt_embedding.py +193 -0
- name_lookup.json +3 -0
- requirements.txt +7 -0
- templates.py +82 -0
- test_lib.py +481 -0
- txt_emb.npy +3 -0
- txt_emb_species.json +3 -0
- txt_emb_species.npy +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
components/metadata.csv filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.venv/
|
2 |
+
__pycache__/
|
README.md
CHANGED
@@ -1,14 +1,11 @@
|
|
1 |
---
|
2 |
title: BioTrove Demo
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
-
|
12 |
-
---
|
13 |
-
|
14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: BioTrove Demo
|
3 |
+
emoji: 🐘
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
+
---
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import heapq
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import polars as pl
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from open_clip import create_model, get_tokenizer
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
from templates import openai_imagenet_template
|
16 |
+
from components.query import get_sample
|
17 |
+
|
18 |
+
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
|
19 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
20 |
+
logger = logging.getLogger()
|
21 |
+
|
22 |
+
hf_token = os.getenv("HF_TOKEN")
|
23 |
+
|
24 |
+
# For sample images
|
25 |
+
METADATA_PATH = "components/metadata.csv"
|
26 |
+
# Read page ID as int and filter out smaller ablation duplicated training split
|
27 |
+
metadata_df = pl.read_csv(METADATA_PATH, low_memory = False)
|
28 |
+
metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
|
29 |
+
|
30 |
+
model_str = "hf-hub:penfever/biotrove-o"
|
31 |
+
tokenizer_str = "ViT-B-16"
|
32 |
+
|
33 |
+
txt_emb_npy = "txt_emb_species.npy"
|
34 |
+
txt_names_json = "txt_emb_species.json"
|
35 |
+
|
36 |
+
min_prob = 1e-9
|
37 |
+
k = 5
|
38 |
+
|
39 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
40 |
+
|
41 |
+
preprocess_img = transforms.Compose(
|
42 |
+
[
|
43 |
+
transforms.ToTensor(),
|
44 |
+
transforms.Resize((224, 224), antialias=True),
|
45 |
+
transforms.Normalize(
|
46 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
47 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
48 |
+
),
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
|
53 |
+
|
54 |
+
open_domain_examples = [
|
55 |
+
["examples/Ursus-arctos.jpeg", "Species"],
|
56 |
+
["examples/Phoca-vitulina.png", "Species"],
|
57 |
+
["examples/Felis-catus.jpeg", "Genus"],
|
58 |
+
["examples/Sarcoscypha-coccinea.jpeg", "Order"],
|
59 |
+
]
|
60 |
+
zero_shot_examples = [
|
61 |
+
[
|
62 |
+
"examples/Ursus-arctos.jpeg",
|
63 |
+
"brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
|
64 |
+
],
|
65 |
+
["examples/milk-snake.png", "coral snake\nmilk snake"],
|
66 |
+
["examples/coral-snake.jpeg", "coral snake\nmilk snake"],
|
67 |
+
[
|
68 |
+
"examples/Carnegiea-gigantea.png",
|
69 |
+
"Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
|
70 |
+
],
|
71 |
+
[
|
72 |
+
"examples/Amanita-muscaria.jpeg",
|
73 |
+
"Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
|
74 |
+
],
|
75 |
+
[
|
76 |
+
"examples/Actinostola-abyssorum.png",
|
77 |
+
"Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
|
78 |
+
],
|
79 |
+
[
|
80 |
+
"examples/Sarcoscypha-coccinea.jpeg",
|
81 |
+
"scarlet elf cup (coccinea)\nscharlachroter kelchbecherling (austriaca)\ncrimson cup (dudleyi)\nstalked scarlet cup (occidentalis)",
|
82 |
+
],
|
83 |
+
[
|
84 |
+
"examples/Onoclea-hintonii.jpg",
|
85 |
+
"Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
|
86 |
+
],
|
87 |
+
[
|
88 |
+
"examples/Onoclea-sensibilis.jpg",
|
89 |
+
"Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
|
90 |
+
],
|
91 |
+
]
|
92 |
+
|
93 |
+
|
94 |
+
def indexed(lst, indices):
|
95 |
+
return [lst[i] for i in indices]
|
96 |
+
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def get_txt_features(classnames, templates):
|
100 |
+
all_features = []
|
101 |
+
for classname in classnames:
|
102 |
+
txts = [template(classname) for template in templates]
|
103 |
+
txts = tokenizer(txts).to(device)
|
104 |
+
txt_features = model.encode_text(txts)
|
105 |
+
txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
|
106 |
+
txt_features /= txt_features.norm()
|
107 |
+
all_features.append(txt_features)
|
108 |
+
all_features = torch.stack(all_features, dim=1)
|
109 |
+
return all_features
|
110 |
+
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
|
114 |
+
classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
|
115 |
+
txt_features = get_txt_features(classes, openai_imagenet_template)
|
116 |
+
|
117 |
+
img = preprocess_img(img).to(device)
|
118 |
+
img_features = model.encode_image(img.unsqueeze(0))
|
119 |
+
img_features = F.normalize(img_features, dim=-1)
|
120 |
+
|
121 |
+
logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
|
122 |
+
probs = F.softmax(logits, dim=0).to("cpu").tolist()
|
123 |
+
return {cls: prob for cls, prob in zip(classes, probs)}
|
124 |
+
|
125 |
+
|
126 |
+
def format_name(taxon, common):
|
127 |
+
taxon = " ".join(taxon)
|
128 |
+
if not common:
|
129 |
+
return taxon
|
130 |
+
return f"{taxon} ({common})"
|
131 |
+
|
132 |
+
|
133 |
+
@torch.no_grad()
|
134 |
+
def open_domain_classification(img, rank: int, return_all=False):
|
135 |
+
"""
|
136 |
+
Predicts from the entire tree of life.
|
137 |
+
If targeting a higher rank than species, then this function predicts among all
|
138 |
+
species, then sums up species-level probabilities for the given rank.
|
139 |
+
"""
|
140 |
+
|
141 |
+
logger.info(f"Starting open domain classification for rank: {rank}")
|
142 |
+
img = preprocess_img(img).to(device)
|
143 |
+
img_features = model.encode_image(img.unsqueeze(0))
|
144 |
+
img_features = F.normalize(img_features, dim=-1)
|
145 |
+
|
146 |
+
logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
|
147 |
+
probs = F.softmax(logits, dim=0)
|
148 |
+
|
149 |
+
if rank + 1 == len(ranks):
|
150 |
+
topk = probs.topk(k)
|
151 |
+
prediction_dict = {
|
152 |
+
format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
|
153 |
+
}
|
154 |
+
logger.info(f"Top K predictions: {prediction_dict}")
|
155 |
+
top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
|
156 |
+
logger.info(f"Top prediction name: {top_prediction_name}")
|
157 |
+
sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
158 |
+
if return_all:
|
159 |
+
return prediction_dict, sample_img, taxon_url
|
160 |
+
return prediction_dict
|
161 |
+
|
162 |
+
output = collections.defaultdict(float)
|
163 |
+
for i in torch.nonzero(probs > min_prob).squeeze():
|
164 |
+
output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
|
165 |
+
|
166 |
+
topk_names = heapq.nlargest(k, output, key=output.get)
|
167 |
+
prediction_dict = {name: output[name] for name in topk_names}
|
168 |
+
logger.info(f"Top K names for output: {topk_names}")
|
169 |
+
logger.info(f"Prediction dictionary: {prediction_dict}")
|
170 |
+
|
171 |
+
top_prediction_name = topk_names[0]
|
172 |
+
logger.info(f"Top prediction name: {top_prediction_name}")
|
173 |
+
sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
174 |
+
logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
|
175 |
+
|
176 |
+
if return_all:
|
177 |
+
return prediction_dict, sample_img, taxon_url
|
178 |
+
return prediction_dict
|
179 |
+
|
180 |
+
|
181 |
+
def change_output(choice):
|
182 |
+
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
logger.info("Starting.")
|
187 |
+
model = create_model(model_str, output_dict=True, require_pretrained=True)
|
188 |
+
model = model.to(device)
|
189 |
+
logger.info("Created model.")
|
190 |
+
|
191 |
+
model = torch.compile(model)
|
192 |
+
logger.info("Compiled model.")
|
193 |
+
|
194 |
+
tokenizer = get_tokenizer(tokenizer_str)
|
195 |
+
|
196 |
+
txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
|
197 |
+
with open(txt_names_json) as fd:
|
198 |
+
txt_names = json.load(fd)
|
199 |
+
|
200 |
+
done = txt_emb.any(axis=0).sum().item()
|
201 |
+
total = txt_emb.shape[1]
|
202 |
+
status_msg = ""
|
203 |
+
if done != total:
|
204 |
+
status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
|
205 |
+
|
206 |
+
with gr.Blocks() as app:
|
207 |
+
|
208 |
+
with gr.Tab("Open-Ended"):
|
209 |
+
with gr.Row(variant = "panel", elem_id = "images_panel"):
|
210 |
+
with gr.Column():
|
211 |
+
img_input = gr.Image(height = 400, sources=["upload"])
|
212 |
+
|
213 |
+
with gr.Column():
|
214 |
+
# display sample image of top predicted taxon
|
215 |
+
sample_img = gr.Image(label = "Sample Image of Predicted Taxon",
|
216 |
+
height = 400,
|
217 |
+
show_download_button = False)
|
218 |
+
|
219 |
+
taxon_url = gr.HTML(label = "More Information",
|
220 |
+
elem_id = "url"
|
221 |
+
)
|
222 |
+
|
223 |
+
with gr.Row():
|
224 |
+
with gr.Column():
|
225 |
+
rank_dropdown = gr.Dropdown(
|
226 |
+
label="Taxonomic Rank",
|
227 |
+
info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
|
228 |
+
choices=ranks,
|
229 |
+
value="Species",
|
230 |
+
type="index",
|
231 |
+
)
|
232 |
+
open_domain_btn = gr.Button("Submit", variant="primary")
|
233 |
+
with gr.Column():
|
234 |
+
open_domain_output = gr.Label(
|
235 |
+
num_top_classes=k,
|
236 |
+
label="Prediction",
|
237 |
+
show_label=True,
|
238 |
+
value=None,
|
239 |
+
)
|
240 |
+
# open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
241 |
+
|
242 |
+
with gr.Row():
|
243 |
+
gr.Examples(
|
244 |
+
examples=open_domain_examples,
|
245 |
+
inputs=[img_input, rank_dropdown],
|
246 |
+
cache_examples=True,
|
247 |
+
fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
|
248 |
+
outputs=[open_domain_output],
|
249 |
+
)
|
250 |
+
with gr.Tab("Zero-Shot"):
|
251 |
+
with gr.Row():
|
252 |
+
img_input_zs = gr.Image(height = 400, sources=["upload"])
|
253 |
+
|
254 |
+
with gr.Row():
|
255 |
+
with gr.Column():
|
256 |
+
classes_txt = gr.Textbox(
|
257 |
+
placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
|
258 |
+
lines=3,
|
259 |
+
label="Classes",
|
260 |
+
show_label=True,
|
261 |
+
info="Use taxonomic names where possible; include common names if possible.",
|
262 |
+
)
|
263 |
+
zero_shot_btn = gr.Button("Submit", variant="primary")
|
264 |
+
|
265 |
+
with gr.Column():
|
266 |
+
zero_shot_output = gr.Label(
|
267 |
+
num_top_classes=k, label="Prediction", show_label=True
|
268 |
+
)
|
269 |
+
# zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
270 |
+
|
271 |
+
with gr.Row():
|
272 |
+
gr.Examples(
|
273 |
+
examples=zero_shot_examples,
|
274 |
+
inputs=[img_input_zs, classes_txt],
|
275 |
+
cache_examples=True,
|
276 |
+
fn=zero_shot_classification,
|
277 |
+
outputs=[zero_shot_output],
|
278 |
+
)
|
279 |
+
rank_dropdown.change(
|
280 |
+
fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
|
281 |
+
)
|
282 |
+
|
283 |
+
open_domain_btn.click(
|
284 |
+
fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
|
285 |
+
inputs=[img_input, rank_dropdown],
|
286 |
+
outputs=[open_domain_output, sample_img, taxon_url],
|
287 |
+
)
|
288 |
+
|
289 |
+
zero_shot_btn.click(
|
290 |
+
fn=zero_shot_classification,
|
291 |
+
inputs=[img_input_zs, classes_txt],
|
292 |
+
outputs=zero_shot_output,
|
293 |
+
)
|
294 |
+
|
295 |
+
# Footer to point out to model and data from app page.
|
296 |
+
gr.Markdown(
|
297 |
+
"""
|
298 |
+
TODO: Add footer with model and data information.
|
299 |
+
"""
|
300 |
+
)
|
301 |
+
|
302 |
+
app.queue(max_size=20)
|
303 |
+
app.launch(share=True)
|
components/metadata.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d8576f6ca106f35387506369a70df01fb92192a740c3b5da2a12ad8303976aad
|
3 |
+
size 233934143
|
components/metadata_readme.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Bioclip Demo
|
3 |
+
emoji: 🐘
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
components/query.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import boto3
|
3 |
+
import requests
|
4 |
+
import numpy as np
|
5 |
+
import polars as pl
|
6 |
+
from PIL import Image
|
7 |
+
from botocore.config import Config
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
# S3 for sample images
|
13 |
+
my_config = Config(
|
14 |
+
region_name='us-east-1'
|
15 |
+
)
|
16 |
+
s3_client = boto3.client('s3', config=my_config)
|
17 |
+
|
18 |
+
# Set basepath for EOL pages for info
|
19 |
+
EOL_URL = "https://eol.org/pages/"
|
20 |
+
RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
|
21 |
+
|
22 |
+
def get_sample(df, pred_taxon, rank):
|
23 |
+
'''
|
24 |
+
Function to retrieve a sample image of the predicted taxon and EOL page link for more info.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
-----------
|
28 |
+
df : DataFrame
|
29 |
+
DataFrame with all sample images listed and their filepaths (in "file_path" column).
|
30 |
+
pred_taxon : str
|
31 |
+
Predicted taxon of the uploaded image.
|
32 |
+
rank : int
|
33 |
+
Index of rank in RANKS chosen for prediction.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
--------
|
37 |
+
img : PIL.Image
|
38 |
+
Sample image of predicted taxon for display.
|
39 |
+
eol_page : str
|
40 |
+
URL to EOL page for the taxon (may be a lower rank, e.g., species sample).
|
41 |
+
'''
|
42 |
+
logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
|
43 |
+
try:
|
44 |
+
filepath, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"Error retrieving sample data: {e}")
|
47 |
+
return None, f"We encountered the following error trying to retrieve a sample image: {e}."
|
48 |
+
if filepath is None:
|
49 |
+
logger.warning(f"No sample image found for taxon: {pred_taxon}")
|
50 |
+
return None, f"Sorry, our EOL images do not include {pred_taxon}."
|
51 |
+
|
52 |
+
# Get sample image of selected individual
|
53 |
+
try:
|
54 |
+
img_src = s3_client.generate_presigned_url('get_object',
|
55 |
+
Params={'Bucket': 'treeoflife-10m-sample-images',
|
56 |
+
'Key': filepath}
|
57 |
+
)
|
58 |
+
img_resp = requests.get(img_src)
|
59 |
+
img = Image.open(io.BytesIO(img_resp.content))
|
60 |
+
full_eol_url = EOL_URL + eol_page_id
|
61 |
+
if is_exact:
|
62 |
+
eol_page = f"<p>Check out the EOL entry for {pred_taxon} to learn more: <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
|
63 |
+
else:
|
64 |
+
eol_page = f"<p>Check out an example EOL entry within {pred_taxon} to learn more: {full_name} <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
|
65 |
+
logger.info(f"Successfully retrieved sample image and EOL page for {pred_taxon}")
|
66 |
+
return img, eol_page
|
67 |
+
except Exception as e:
|
68 |
+
logger.error(f"Error retrieving sample image: {e}")
|
69 |
+
return None, f"We encountered the following error trying to retrieve a sample image: {e}."
|
70 |
+
|
71 |
+
def get_sample_data(df, pred_taxon, rank):
|
72 |
+
'''
|
73 |
+
Function to randomly select a sample individual of the given taxon and provide associated native location.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
-----------
|
77 |
+
df : DataFrame
|
78 |
+
DataFrame with all sample images listed and their filepaths (in "file_path" column).
|
79 |
+
pred_taxon : str
|
80 |
+
Predicted taxon of the uploaded image.
|
81 |
+
rank : int
|
82 |
+
Index of rank in RANKS chosen for prediction.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
--------
|
86 |
+
filepath : str
|
87 |
+
Filepath of selected sample image for predicted taxon.
|
88 |
+
eol_page_id : str
|
89 |
+
EOL page ID associated with predicted taxon for more information.
|
90 |
+
full_name : str
|
91 |
+
Full taxonomic name of the selected sample.
|
92 |
+
is_exact : bool
|
93 |
+
Flag indicating if the match is exact (i.e., with empty lower ranks).
|
94 |
+
'''
|
95 |
+
for idx in range(rank + 1):
|
96 |
+
taxon = RANKS[idx]
|
97 |
+
target_taxon = pred_taxon.split(" ")[idx]
|
98 |
+
df = df.filter(pl.col(taxon) == target_taxon)
|
99 |
+
|
100 |
+
if df.shape[0] == 0:
|
101 |
+
return None, np.nan, "", False
|
102 |
+
|
103 |
+
# First, try to find entries with empty lower ranks
|
104 |
+
exact_df = df
|
105 |
+
for lower_rank in RANKS[rank + 1:]:
|
106 |
+
exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))
|
107 |
+
|
108 |
+
if exact_df.shape[0] > 0:
|
109 |
+
df_filtered = exact_df.sample()
|
110 |
+
full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
|
111 |
+
return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, True
|
112 |
+
|
113 |
+
# If no exact matches, return any entry with the specified rank
|
114 |
+
df_filtered = df.sample()
|
115 |
+
full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0)) + " " + " ".join(df_filtered.select(RANKS[rank+1:]).row(0))
|
116 |
+
return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, False
|
components/sync_samples_to_s3.bash
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
<<COMMENT
|
4 |
+
Usage:
|
5 |
+
bash sync_samples_to_s3.bash <BASE_DIR>
|
6 |
+
|
7 |
+
Dependencies:
|
8 |
+
- awscli (https://aws.amazon.com/cli/)
|
9 |
+
Credentials to export as environment variables:
|
10 |
+
- AWS_ACCESS_KEY_ID
|
11 |
+
- AWS_SECRET_ACCESS_KEY
|
12 |
+
COMMENT
|
13 |
+
|
14 |
+
# Check if a valid directory is provided as an argument
|
15 |
+
if [ -z "$1" ]; then
|
16 |
+
echo "Usage: $0 <BASE_DIR>"
|
17 |
+
exit 1
|
18 |
+
fi
|
19 |
+
|
20 |
+
if [ ! -d "$1" ]; then
|
21 |
+
echo "Error: $1 is not a valid directory"
|
22 |
+
exit 1
|
23 |
+
fi
|
24 |
+
|
25 |
+
BASE_DIR="$1"
|
26 |
+
S3_BUCKET="s3://treeoflife-10m-sample-images"
|
27 |
+
|
28 |
+
# Loop through all directories and sync them to S3
|
29 |
+
for dir in $BASE_DIR/*; do
|
30 |
+
if [ -d "$dir" ]; then
|
31 |
+
dir_name=$(basename "$dir")
|
32 |
+
aws s3 sync "$dir" "$S3_BUCKET/$dir_name/"
|
33 |
+
fi
|
34 |
+
done
|
embed_texts.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
#SBATCH --nodes=1
|
3 |
+
#SBATCH --account=PAS2136
|
4 |
+
#SBATCH --gpus-per-node=1
|
5 |
+
#SBATCH --ntasks-per-node=10
|
6 |
+
#SBATCH --job-name=embed-treeoflife
|
7 |
+
#SBATCH --time=12:00:00
|
8 |
+
#SBATCH --partition=gpu
|
9 |
+
|
10 |
+
python make_txt_embedding.py \
|
11 |
+
--catalog-path /fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/predicted-statistics.csv \
|
12 |
+
--out-path text_emb.bin
|
examples/Actinostola-abyssorum.png
ADDED
Git LFS Details
|
examples/Amanita-muscaria.jpeg
ADDED
Git LFS Details
|
examples/Carnegiea-gigantea.png
ADDED
Git LFS Details
|
examples/Felis-catus.jpeg
ADDED
Git LFS Details
|
examples/Onoclea-hintonii.jpg
ADDED
examples/Onoclea-sensibilis.jpg
ADDED
examples/Phoca-vitulina.png
ADDED
Git LFS Details
|
examples/Sarcoscypha-coccinea.jpeg
ADDED
Git LFS Details
|
examples/Ursus-arctos.jpeg
ADDED
Git LFS Details
|
examples/coral-snake.jpeg
ADDED
Git LFS Details
|
examples/milk-snake.png
ADDED
Git LFS Details
|
lib.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Mostly a TaxonomicTree class that implements a taxonomy and some helpers for easily
|
3 |
+
walking and looking in the tree.
|
4 |
+
|
5 |
+
A tree is an arrangement of TaxonomicNodes.
|
6 |
+
|
7 |
+
|
8 |
+
"""
|
9 |
+
|
10 |
+
|
11 |
+
import itertools
|
12 |
+
import json
|
13 |
+
|
14 |
+
|
15 |
+
class TaxonomicNode:
|
16 |
+
__slots__ = ("name", "index", "root", "_children")
|
17 |
+
|
18 |
+
def __init__(self, name, index, root):
|
19 |
+
self.name = name
|
20 |
+
self.index = index
|
21 |
+
self.root = root
|
22 |
+
self._children = {}
|
23 |
+
|
24 |
+
def add(self, name):
|
25 |
+
added = 0
|
26 |
+
if not name:
|
27 |
+
return added
|
28 |
+
|
29 |
+
first, rest = name[0], name[1:]
|
30 |
+
if first not in self._children:
|
31 |
+
self._children[first] = TaxonomicNode(first, self.root.size, self.root)
|
32 |
+
self.root.size += 1
|
33 |
+
|
34 |
+
self._children[first].add(rest)
|
35 |
+
|
36 |
+
def children(self, name):
|
37 |
+
if not name:
|
38 |
+
return set((child.name, child.index) for child in self._children.values())
|
39 |
+
|
40 |
+
first, rest = name[0], name[1:]
|
41 |
+
if first not in self._children:
|
42 |
+
return set()
|
43 |
+
|
44 |
+
return self._children[first].children(rest)
|
45 |
+
|
46 |
+
def descendants(self, prefix=None):
|
47 |
+
"""Iterates over all values in the subtree that match prefix."""
|
48 |
+
|
49 |
+
if not prefix:
|
50 |
+
yield (self.name,), self.index
|
51 |
+
for child in self._children.values():
|
52 |
+
for name, i in child.descendants():
|
53 |
+
yield (self.name, *name), i
|
54 |
+
return
|
55 |
+
|
56 |
+
first, rest = prefix[0], prefix[1:]
|
57 |
+
if first not in self._children:
|
58 |
+
return
|
59 |
+
|
60 |
+
for name, i in self._children[first].descendants(rest):
|
61 |
+
yield (self.name, *name), i
|
62 |
+
|
63 |
+
def values(self):
|
64 |
+
"""Iterates over all (name, i) pairs in the tree."""
|
65 |
+
yield (self.name,), self.index
|
66 |
+
|
67 |
+
for child in self._children.values():
|
68 |
+
for name, index in child.values():
|
69 |
+
yield (self.name, *name), index
|
70 |
+
|
71 |
+
@classmethod
|
72 |
+
def from_dict(cls, dct, root):
|
73 |
+
node = cls(dct["name"], dct["index"], root)
|
74 |
+
node._children = {
|
75 |
+
child["name"]: cls.from_dict(child, root) for child in dct["children"]
|
76 |
+
}
|
77 |
+
return node
|
78 |
+
|
79 |
+
|
80 |
+
class TaxonomicTree:
|
81 |
+
"""
|
82 |
+
Efficient structure for finding taxonomic names and their descendants.
|
83 |
+
Also returns an integer index i for each possible name.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self):
|
87 |
+
self.kingdoms = {}
|
88 |
+
self.size = 0
|
89 |
+
|
90 |
+
def add(self, name: list[str]):
|
91 |
+
if not name:
|
92 |
+
return
|
93 |
+
|
94 |
+
first, rest = name[0], name[1:]
|
95 |
+
if first not in self.kingdoms:
|
96 |
+
self.kingdoms[first] = TaxonomicNode(first, self.size, self)
|
97 |
+
self.size += 1
|
98 |
+
|
99 |
+
self.kingdoms[first].add(rest)
|
100 |
+
|
101 |
+
def children(self, name=None):
|
102 |
+
if not name:
|
103 |
+
return set(
|
104 |
+
(kingdom.name, kingdom.index) for kingdom in self.kingdoms.values()
|
105 |
+
)
|
106 |
+
|
107 |
+
first, rest = name[0], name[1:]
|
108 |
+
if first not in self.kingdoms:
|
109 |
+
return set()
|
110 |
+
|
111 |
+
return self.kingdoms[first].children(rest)
|
112 |
+
|
113 |
+
def descendants(self, prefix=None):
|
114 |
+
"""Iterates over all values in the tree that match prefix."""
|
115 |
+
if not prefix:
|
116 |
+
# Give them all the subnodes
|
117 |
+
for kingdom in self.kingdoms.values():
|
118 |
+
yield from kingdom.descendants()
|
119 |
+
|
120 |
+
return
|
121 |
+
|
122 |
+
first, rest = prefix[0], prefix[1:]
|
123 |
+
if first not in self.kingdoms:
|
124 |
+
return
|
125 |
+
|
126 |
+
yield from self.kingdoms[first].descendants(rest)
|
127 |
+
|
128 |
+
def values(self):
|
129 |
+
"""Iterates over all (name, i) pairs in the tree."""
|
130 |
+
for kingdom in self.kingdoms.values():
|
131 |
+
yield from kingdom.values()
|
132 |
+
|
133 |
+
def __len__(self):
|
134 |
+
return self.size
|
135 |
+
|
136 |
+
@classmethod
|
137 |
+
def from_dict(cls, dct):
|
138 |
+
tree = cls()
|
139 |
+
tree.kingdoms = {
|
140 |
+
kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
|
141 |
+
for kingdom in dct["kingdoms"]
|
142 |
+
}
|
143 |
+
tree.size = dct["size"]
|
144 |
+
return tree
|
145 |
+
|
146 |
+
|
147 |
+
class TaxonomicJsonEncoder(json.JSONEncoder):
|
148 |
+
def default(self, obj):
|
149 |
+
if isinstance(obj, TaxonomicNode):
|
150 |
+
return {
|
151 |
+
"name": obj.name,
|
152 |
+
"index": obj.index,
|
153 |
+
"children": list(obj._children.values()),
|
154 |
+
}
|
155 |
+
elif isinstance(obj, TaxonomicTree):
|
156 |
+
return {
|
157 |
+
"kingdoms": list(obj.kingdoms.values()),
|
158 |
+
"size": obj.size,
|
159 |
+
}
|
160 |
+
else:
|
161 |
+
super().default(self, obj)
|
162 |
+
|
163 |
+
|
164 |
+
def batched(iterable, n):
|
165 |
+
# batched('ABCDEFG', 3) --> ABC DEF G
|
166 |
+
if n < 1:
|
167 |
+
raise ValueError("n must be at least one")
|
168 |
+
it = iter(iterable)
|
169 |
+
while batch := tuple(itertools.islice(it, n)):
|
170 |
+
yield zip(*batch)
|
make_txt_embedding.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Makes the entire set of text emebeddings for all possible names in the tree of life.
|
3 |
+
Uses the catalog.csv file from TreeOfLife-10M.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import csv
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from open_clip import create_model, get_tokenizer
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
import lib
|
19 |
+
from templates import openai_imagenet_template
|
20 |
+
|
21 |
+
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
|
22 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
23 |
+
logger = logging.getLogger()
|
24 |
+
|
25 |
+
model_str = "hf-hub:imageomics/bioclip"
|
26 |
+
tokenizer_str = "ViT-B-16"
|
27 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
28 |
+
|
29 |
+
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
|
30 |
+
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def write_txt_features(name_lookup):
|
34 |
+
if os.path.isfile(args.out_path):
|
35 |
+
all_features = np.load(args.out_path)
|
36 |
+
else:
|
37 |
+
all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)
|
38 |
+
|
39 |
+
batch_size = args.batch_size // len(openai_imagenet_template)
|
40 |
+
for batch, (names, indices) in enumerate(
|
41 |
+
tqdm(
|
42 |
+
lib.batched(name_lookup.values(), batch_size),
|
43 |
+
desc="txt feats",
|
44 |
+
total=len(name_lookup) // batch_size,
|
45 |
+
)
|
46 |
+
):
|
47 |
+
# Skip if any non-zero elements
|
48 |
+
if all_features[:, indices].any():
|
49 |
+
logger.info(f"Skipping batch {batch}")
|
50 |
+
continue
|
51 |
+
|
52 |
+
txts = [
|
53 |
+
template(name) for name in names for template in openai_imagenet_template
|
54 |
+
]
|
55 |
+
txts = tokenizer(txts).to(device)
|
56 |
+
txt_features = model.encode_text(txts)
|
57 |
+
txt_features = torch.reshape(
|
58 |
+
txt_features, (len(names), len(openai_imagenet_template), 512)
|
59 |
+
)
|
60 |
+
txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
|
61 |
+
txt_features /= txt_features.norm(dim=1, keepdim=True)
|
62 |
+
all_features[:, indices] = txt_features.T.cpu().numpy()
|
63 |
+
|
64 |
+
if batch % 100 == 0:
|
65 |
+
np.save(args.out_path, all_features)
|
66 |
+
|
67 |
+
np.save(args.out_path, all_features)
|
68 |
+
|
69 |
+
|
70 |
+
def convert_txt_features_to_avgs(name_lookup):
|
71 |
+
assert os.path.isfile(args.out_path)
|
72 |
+
|
73 |
+
# Put that big boy on the GPU. We're going fast.
|
74 |
+
all_features = torch.from_numpy(np.load(args.out_path)).to(device)
|
75 |
+
logger.info("Loaded text features from disk to %s.", device)
|
76 |
+
|
77 |
+
names_by_rank = [set() for rank in ranks]
|
78 |
+
for name, index in tqdm(name_lookup.values()):
|
79 |
+
i = len(name) - 1
|
80 |
+
names_by_rank[i].add((name, index))
|
81 |
+
|
82 |
+
zeroed = 0
|
83 |
+
for i, rank in reversed(list(enumerate(ranks))):
|
84 |
+
if rank == "Species":
|
85 |
+
continue
|
86 |
+
for name, index in tqdm(names_by_rank[i], desc=rank):
|
87 |
+
species = tuple(
|
88 |
+
zip(
|
89 |
+
*(
|
90 |
+
(d, i)
|
91 |
+
for d, i in name_lookup.descendants(prefix=name)
|
92 |
+
if len(d) >= 6
|
93 |
+
)
|
94 |
+
)
|
95 |
+
)
|
96 |
+
if not species:
|
97 |
+
logger.warning("No species for %s.", " ".join(name))
|
98 |
+
all_features[:, index] = 0.0
|
99 |
+
zeroed += 1
|
100 |
+
continue
|
101 |
+
|
102 |
+
values, indices = species
|
103 |
+
mean = all_features[:, indices].mean(dim=1)
|
104 |
+
all_features[:, index] = F.normalize(mean, dim=0)
|
105 |
+
|
106 |
+
out_path, ext = os.path.splitext(args.out_path)
|
107 |
+
np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())
|
108 |
+
if zeroed:
|
109 |
+
logger.warning(
|
110 |
+
"Zeroed out %d nodes because they didn't have any genus or species-level labels.",
|
111 |
+
zeroed,
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
def convert_txt_features_to_species_only(name_lookup):
|
116 |
+
assert os.path.isfile(args.out_path)
|
117 |
+
|
118 |
+
all_features = np.load(args.out_path)
|
119 |
+
logger.info("Loaded text features from disk.")
|
120 |
+
|
121 |
+
species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
|
122 |
+
species_features = np.zeros((512, len(species)), dtype=np.float32)
|
123 |
+
species_names = [""] * len(species)
|
124 |
+
|
125 |
+
for new_i, (name, old_i) in enumerate(tqdm(species)):
|
126 |
+
species_features[:, new_i] = all_features[:, old_i]
|
127 |
+
species_names[new_i] = name
|
128 |
+
|
129 |
+
out_path, ext = os.path.splitext(args.out_path)
|
130 |
+
np.save(f"{out_path}_species{ext}", species_features)
|
131 |
+
with open(f"{out_path}_species.json", "w") as fd:
|
132 |
+
json.dump(species_names, fd, indent=2)
|
133 |
+
|
134 |
+
|
135 |
+
def get_name_lookup(catalog_path, cache_path):
|
136 |
+
if os.path.isfile(cache_path):
|
137 |
+
with open(cache_path) as fd:
|
138 |
+
lookup = lib.TaxonomicTree.from_dict(json.load(fd))
|
139 |
+
return lookup
|
140 |
+
|
141 |
+
lookup = lib.TaxonomicTree()
|
142 |
+
|
143 |
+
with open(catalog_path) as fd:
|
144 |
+
reader = csv.DictReader(fd)
|
145 |
+
for row in tqdm(reader, desc="catalog"):
|
146 |
+
name = [
|
147 |
+
row["kingdom"],
|
148 |
+
row["phylum"],
|
149 |
+
row["class"],
|
150 |
+
row["order"],
|
151 |
+
row["family"],
|
152 |
+
row["genus"],
|
153 |
+
row["species"],
|
154 |
+
]
|
155 |
+
if any(not value for value in name):
|
156 |
+
name = name[: name.index("")]
|
157 |
+
lookup.add(name)
|
158 |
+
|
159 |
+
with open(args.name_cache_path, "w") as fd:
|
160 |
+
json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
|
161 |
+
|
162 |
+
return lookup
|
163 |
+
|
164 |
+
|
165 |
+
if __name__ == "__main__":
|
166 |
+
parser = argparse.ArgumentParser()
|
167 |
+
parser.add_argument(
|
168 |
+
"--catalog-path",
|
169 |
+
help="Path to the catalog.csv file from TreeOfLife-10M.",
|
170 |
+
required=True,
|
171 |
+
)
|
172 |
+
parser.add_argument("--out-path", help="Path to the output file.", required=True)
|
173 |
+
parser.add_argument(
|
174 |
+
"--name-cache-path",
|
175 |
+
help="Path to the name cache file.",
|
176 |
+
default="name_lookup.json",
|
177 |
+
)
|
178 |
+
parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
|
179 |
+
args = parser.parse_args()
|
180 |
+
|
181 |
+
name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
|
182 |
+
logger.info("Got name lookup.")
|
183 |
+
|
184 |
+
model = create_model(model_str, output_dict=True, require_pretrained=True)
|
185 |
+
model = model.to(device)
|
186 |
+
logger.info("Created model.")
|
187 |
+
model = torch.compile(model)
|
188 |
+
logger.info("Compiled model.")
|
189 |
+
|
190 |
+
tokenizer = get_tokenizer(tokenizer_str)
|
191 |
+
write_txt_features(name_lookup)
|
192 |
+
convert_txt_features_to_avgs(name_lookup)
|
193 |
+
convert_txt_features_to_species_only(name_lookup)
|
name_lookup.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:20d731d9d901f1c17927187bc87e4a2513279845a1a6ba5982dbf779f2ac1434
|
3 |
+
size 26462858
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
open_clip_torch
|
2 |
+
torchvision
|
3 |
+
torch
|
4 |
+
gradio
|
5 |
+
polars
|
6 |
+
pillow
|
7 |
+
boto3
|
templates.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openai_imagenet_template = [
|
2 |
+
lambda c: f"a bad photo of a {c}.",
|
3 |
+
lambda c: f"a photo of many {c}.",
|
4 |
+
lambda c: f"a sculpture of a {c}.",
|
5 |
+
lambda c: f"a photo of the hard to see {c}.",
|
6 |
+
lambda c: f"a low resolution photo of the {c}.",
|
7 |
+
lambda c: f"a rendering of a {c}.",
|
8 |
+
lambda c: f"graffiti of a {c}.",
|
9 |
+
lambda c: f"a bad photo of the {c}.",
|
10 |
+
lambda c: f"a cropped photo of the {c}.",
|
11 |
+
lambda c: f"a tattoo of a {c}.",
|
12 |
+
lambda c: f"the embroidered {c}.",
|
13 |
+
lambda c: f"a photo of a hard to see {c}.",
|
14 |
+
lambda c: f"a bright photo of a {c}.",
|
15 |
+
lambda c: f"a photo of a clean {c}.",
|
16 |
+
lambda c: f"a photo of a dirty {c}.",
|
17 |
+
lambda c: f"a dark photo of the {c}.",
|
18 |
+
lambda c: f"a drawing of a {c}.",
|
19 |
+
lambda c: f"a photo of my {c}.",
|
20 |
+
lambda c: f"the plastic {c}.",
|
21 |
+
lambda c: f"a photo of the cool {c}.",
|
22 |
+
lambda c: f"a close-up photo of a {c}.",
|
23 |
+
lambda c: f"a black and white photo of the {c}.",
|
24 |
+
lambda c: f"a painting of the {c}.",
|
25 |
+
lambda c: f"a painting of a {c}.",
|
26 |
+
lambda c: f"a pixelated photo of the {c}.",
|
27 |
+
lambda c: f"a sculpture of the {c}.",
|
28 |
+
lambda c: f"a bright photo of the {c}.",
|
29 |
+
lambda c: f"a cropped photo of a {c}.",
|
30 |
+
lambda c: f"a plastic {c}.",
|
31 |
+
lambda c: f"a photo of the dirty {c}.",
|
32 |
+
lambda c: f"a jpeg corrupted photo of a {c}.",
|
33 |
+
lambda c: f"a blurry photo of the {c}.",
|
34 |
+
lambda c: f"a photo of the {c}.",
|
35 |
+
lambda c: f"a good photo of the {c}.",
|
36 |
+
lambda c: f"a rendering of the {c}.",
|
37 |
+
lambda c: f"a {c} in a video game.",
|
38 |
+
lambda c: f"a photo of one {c}.",
|
39 |
+
lambda c: f"a doodle of a {c}.",
|
40 |
+
lambda c: f"a close-up photo of the {c}.",
|
41 |
+
lambda c: f"a photo of a {c}.",
|
42 |
+
lambda c: f"the origami {c}.",
|
43 |
+
lambda c: f"the {c} in a video game.",
|
44 |
+
lambda c: f"a sketch of a {c}.",
|
45 |
+
lambda c: f"a doodle of the {c}.",
|
46 |
+
lambda c: f"a origami {c}.",
|
47 |
+
lambda c: f"a low resolution photo of a {c}.",
|
48 |
+
lambda c: f"the toy {c}.",
|
49 |
+
lambda c: f"a rendition of the {c}.",
|
50 |
+
lambda c: f"a photo of the clean {c}.",
|
51 |
+
lambda c: f"a photo of a large {c}.",
|
52 |
+
lambda c: f"a rendition of a {c}.",
|
53 |
+
lambda c: f"a photo of a nice {c}.",
|
54 |
+
lambda c: f"a photo of a weird {c}.",
|
55 |
+
lambda c: f"a blurry photo of a {c}.",
|
56 |
+
lambda c: f"a cartoon {c}.",
|
57 |
+
lambda c: f"art of a {c}.",
|
58 |
+
lambda c: f"a sketch of the {c}.",
|
59 |
+
lambda c: f"a embroidered {c}.",
|
60 |
+
lambda c: f"a pixelated photo of a {c}.",
|
61 |
+
lambda c: f"itap of the {c}.",
|
62 |
+
lambda c: f"a jpeg corrupted photo of the {c}.",
|
63 |
+
lambda c: f"a good photo of a {c}.",
|
64 |
+
lambda c: f"a plushie {c}.",
|
65 |
+
lambda c: f"a photo of the nice {c}.",
|
66 |
+
lambda c: f"a photo of the small {c}.",
|
67 |
+
lambda c: f"a photo of the weird {c}.",
|
68 |
+
lambda c: f"the cartoon {c}.",
|
69 |
+
lambda c: f"art of the {c}.",
|
70 |
+
lambda c: f"a drawing of the {c}.",
|
71 |
+
lambda c: f"a photo of the large {c}.",
|
72 |
+
lambda c: f"a black and white photo of a {c}.",
|
73 |
+
lambda c: f"the plushie {c}.",
|
74 |
+
lambda c: f"a dark photo of a {c}.",
|
75 |
+
lambda c: f"itap of a {c}.",
|
76 |
+
lambda c: f"graffiti of the {c}.",
|
77 |
+
lambda c: f"a toy {c}.",
|
78 |
+
lambda c: f"itap of my {c}.",
|
79 |
+
lambda c: f"a photo of a cool {c}.",
|
80 |
+
lambda c: f"a photo of a small {c}.",
|
81 |
+
lambda c: f"a tattoo of the {c}.",
|
82 |
+
]
|
test_lib.py
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import lib
|
2 |
+
|
3 |
+
|
4 |
+
def test_taxonomiclookup_empty():
|
5 |
+
lookup = lib.TaxonomicTree()
|
6 |
+
assert lookup.size == 0
|
7 |
+
|
8 |
+
|
9 |
+
def test_taxonomiclookup_kingdom_size():
|
10 |
+
lookup = lib.TaxonomicTree()
|
11 |
+
|
12 |
+
lookup.add(("Animalia",))
|
13 |
+
|
14 |
+
assert lookup.size == 1
|
15 |
+
|
16 |
+
|
17 |
+
def test_taxonomiclookup_genus_size():
|
18 |
+
lookup = lib.TaxonomicTree()
|
19 |
+
|
20 |
+
lookup.add(
|
21 |
+
(
|
22 |
+
"Animalia",
|
23 |
+
"Chordata",
|
24 |
+
"Aves",
|
25 |
+
"Accipitriformes",
|
26 |
+
"Accipitridae",
|
27 |
+
"Halieaeetus",
|
28 |
+
)
|
29 |
+
)
|
30 |
+
|
31 |
+
assert lookup.size == 6
|
32 |
+
|
33 |
+
|
34 |
+
def test_taxonomictree_kingdom_children():
|
35 |
+
lookup = lib.TaxonomicTree()
|
36 |
+
|
37 |
+
lookup.add(
|
38 |
+
(
|
39 |
+
"Animalia",
|
40 |
+
"Chordata",
|
41 |
+
"Aves",
|
42 |
+
"Accipitriformes",
|
43 |
+
"Accipitridae",
|
44 |
+
"Halieaeetus",
|
45 |
+
)
|
46 |
+
)
|
47 |
+
|
48 |
+
expected = set([("Animalia", 0)])
|
49 |
+
actual = lookup.children()
|
50 |
+
assert actual == expected
|
51 |
+
|
52 |
+
|
53 |
+
def test_taxonomiclookup_children_of_animal_only_birds():
|
54 |
+
lookup = lib.TaxonomicTree()
|
55 |
+
|
56 |
+
lookup.add(
|
57 |
+
(
|
58 |
+
"Animalia",
|
59 |
+
"Chordata",
|
60 |
+
"Aves",
|
61 |
+
"Accipitriformes",
|
62 |
+
"Accipitridae",
|
63 |
+
"Halieaeetus",
|
64 |
+
"leucocephalus",
|
65 |
+
)
|
66 |
+
)
|
67 |
+
lookup.add(
|
68 |
+
(
|
69 |
+
"Animalia",
|
70 |
+
"Chordata",
|
71 |
+
"Aves",
|
72 |
+
"Strigiformes",
|
73 |
+
"Strigidae",
|
74 |
+
"Ninox",
|
75 |
+
"scutulata",
|
76 |
+
)
|
77 |
+
)
|
78 |
+
lookup.add(
|
79 |
+
(
|
80 |
+
"Animalia",
|
81 |
+
"Chordata",
|
82 |
+
"Aves",
|
83 |
+
"Strigiformes",
|
84 |
+
"Strigidae",
|
85 |
+
"Ninox",
|
86 |
+
"plesseni",
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
actual = lookup.children(("Animalia",))
|
91 |
+
expected = set([("Chordata", 1)])
|
92 |
+
assert actual == expected
|
93 |
+
|
94 |
+
|
95 |
+
def test_taxonomiclookup_children_of_animal():
|
96 |
+
lookup = lib.TaxonomicTree()
|
97 |
+
|
98 |
+
lookup.add(
|
99 |
+
(
|
100 |
+
"Animalia",
|
101 |
+
"Chordata",
|
102 |
+
"Aves",
|
103 |
+
"Accipitriformes",
|
104 |
+
"Accipitridae",
|
105 |
+
"Halieaeetus",
|
106 |
+
"leucocephalus",
|
107 |
+
)
|
108 |
+
)
|
109 |
+
lookup.add(
|
110 |
+
(
|
111 |
+
"Animalia",
|
112 |
+
"Chordata",
|
113 |
+
"Aves",
|
114 |
+
"Strigiformes",
|
115 |
+
"Strigidae",
|
116 |
+
"Ninox",
|
117 |
+
"scutulata",
|
118 |
+
)
|
119 |
+
)
|
120 |
+
lookup.add(
|
121 |
+
(
|
122 |
+
"Animalia",
|
123 |
+
"Chordata",
|
124 |
+
"Aves",
|
125 |
+
"Strigiformes",
|
126 |
+
"Strigidae",
|
127 |
+
"Ninox",
|
128 |
+
"plesseni",
|
129 |
+
)
|
130 |
+
)
|
131 |
+
lookup.add(
|
132 |
+
(
|
133 |
+
"Animalia",
|
134 |
+
"Chordata",
|
135 |
+
"Mammalia",
|
136 |
+
"Primates",
|
137 |
+
"Hominidae",
|
138 |
+
"Gorilla",
|
139 |
+
"gorilla",
|
140 |
+
)
|
141 |
+
)
|
142 |
+
lookup.add(
|
143 |
+
(
|
144 |
+
"Animalia",
|
145 |
+
"Arthropoda",
|
146 |
+
"Insecta",
|
147 |
+
"Hymenoptera",
|
148 |
+
"Apidae",
|
149 |
+
"Bombus",
|
150 |
+
"balteatus",
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
actual = lookup.children(("Animalia",))
|
155 |
+
expected = set([("Chordata", 1), ("Arthropoda", 17)])
|
156 |
+
assert actual == expected
|
157 |
+
|
158 |
+
|
159 |
+
def test_taxonomiclookup_children_of_chordata():
|
160 |
+
lookup = lib.TaxonomicTree()
|
161 |
+
|
162 |
+
lookup.add(
|
163 |
+
(
|
164 |
+
"Animalia",
|
165 |
+
"Chordata",
|
166 |
+
"Aves",
|
167 |
+
"Accipitriformes",
|
168 |
+
"Accipitridae",
|
169 |
+
"Halieaeetus",
|
170 |
+
"leucocephalus",
|
171 |
+
)
|
172 |
+
)
|
173 |
+
lookup.add(
|
174 |
+
(
|
175 |
+
"Animalia",
|
176 |
+
"Chordata",
|
177 |
+
"Aves",
|
178 |
+
"Strigiformes",
|
179 |
+
"Strigidae",
|
180 |
+
"Ninox",
|
181 |
+
"scutulata",
|
182 |
+
)
|
183 |
+
)
|
184 |
+
lookup.add(
|
185 |
+
(
|
186 |
+
"Animalia",
|
187 |
+
"Chordata",
|
188 |
+
"Aves",
|
189 |
+
"Strigiformes",
|
190 |
+
"Strigidae",
|
191 |
+
"Ninox",
|
192 |
+
"plesseni",
|
193 |
+
)
|
194 |
+
)
|
195 |
+
lookup.add(
|
196 |
+
(
|
197 |
+
"Animalia",
|
198 |
+
"Chordata",
|
199 |
+
"Mammalia",
|
200 |
+
"Primates",
|
201 |
+
"Hominidae",
|
202 |
+
"Gorilla",
|
203 |
+
"gorilla",
|
204 |
+
)
|
205 |
+
)
|
206 |
+
lookup.add(
|
207 |
+
(
|
208 |
+
"Animalia",
|
209 |
+
"Arthropoda",
|
210 |
+
"Insecta",
|
211 |
+
"Hymenoptera",
|
212 |
+
"Apidae",
|
213 |
+
"Bombus",
|
214 |
+
"balteatus",
|
215 |
+
)
|
216 |
+
)
|
217 |
+
|
218 |
+
actual = lookup.children(("Animalia", "Chordata"))
|
219 |
+
expected = set([("Aves", 2), ("Mammalia", 12)])
|
220 |
+
assert actual == expected
|
221 |
+
|
222 |
+
|
223 |
+
def test_taxonomiclookup_children_of_strigiformes():
|
224 |
+
lookup = lib.TaxonomicTree()
|
225 |
+
|
226 |
+
lookup.add(
|
227 |
+
(
|
228 |
+
"Animalia",
|
229 |
+
"Chordata",
|
230 |
+
"Aves",
|
231 |
+
"Accipitriformes",
|
232 |
+
"Accipitridae",
|
233 |
+
"Halieaeetus",
|
234 |
+
"leucocephalus",
|
235 |
+
)
|
236 |
+
)
|
237 |
+
lookup.add(
|
238 |
+
(
|
239 |
+
"Animalia",
|
240 |
+
"Chordata",
|
241 |
+
"Aves",
|
242 |
+
"Strigiformes",
|
243 |
+
"Strigidae",
|
244 |
+
"Ninox",
|
245 |
+
"scutulata",
|
246 |
+
)
|
247 |
+
)
|
248 |
+
lookup.add(
|
249 |
+
(
|
250 |
+
"Animalia",
|
251 |
+
"Chordata",
|
252 |
+
"Aves",
|
253 |
+
"Strigiformes",
|
254 |
+
"Strigidae",
|
255 |
+
"Ninox",
|
256 |
+
"plesseni",
|
257 |
+
)
|
258 |
+
)
|
259 |
+
lookup.add(
|
260 |
+
(
|
261 |
+
"Animalia",
|
262 |
+
"Chordata",
|
263 |
+
"Mammalia",
|
264 |
+
"Primates",
|
265 |
+
"Hominidae",
|
266 |
+
"Gorilla",
|
267 |
+
"gorilla",
|
268 |
+
)
|
269 |
+
)
|
270 |
+
lookup.add(
|
271 |
+
(
|
272 |
+
"Animalia",
|
273 |
+
"Arthropoda",
|
274 |
+
"Insecta",
|
275 |
+
"Hymenoptera",
|
276 |
+
"Apidae",
|
277 |
+
"Bombus",
|
278 |
+
"balteatus",
|
279 |
+
)
|
280 |
+
)
|
281 |
+
|
282 |
+
actual = lookup.children(("Animalia", "Chordata", "Aves", "Strigiformes"))
|
283 |
+
expected = set([("Strigidae", 8)])
|
284 |
+
assert actual == expected
|
285 |
+
|
286 |
+
|
287 |
+
def test_taxonomiclookup_children_of_ninox():
|
288 |
+
lookup = lib.TaxonomicTree()
|
289 |
+
|
290 |
+
lookup.add(
|
291 |
+
(
|
292 |
+
"Animalia",
|
293 |
+
"Chordata",
|
294 |
+
"Aves",
|
295 |
+
"Accipitriformes",
|
296 |
+
"Accipitridae",
|
297 |
+
"Halieaeetus",
|
298 |
+
"leucocephalus",
|
299 |
+
)
|
300 |
+
)
|
301 |
+
lookup.add(
|
302 |
+
(
|
303 |
+
"Animalia",
|
304 |
+
"Chordata",
|
305 |
+
"Aves",
|
306 |
+
"Strigiformes",
|
307 |
+
"Strigidae",
|
308 |
+
"Ninox",
|
309 |
+
"scutulata",
|
310 |
+
)
|
311 |
+
)
|
312 |
+
lookup.add(
|
313 |
+
(
|
314 |
+
"Animalia",
|
315 |
+
"Chordata",
|
316 |
+
"Aves",
|
317 |
+
"Strigiformes",
|
318 |
+
"Strigidae",
|
319 |
+
"Ninox",
|
320 |
+
"plesseni",
|
321 |
+
)
|
322 |
+
)
|
323 |
+
lookup.add(
|
324 |
+
(
|
325 |
+
"Animalia",
|
326 |
+
"Chordata",
|
327 |
+
"Mammalia",
|
328 |
+
"Primates",
|
329 |
+
"Hominidae",
|
330 |
+
"Gorilla",
|
331 |
+
"gorilla",
|
332 |
+
)
|
333 |
+
)
|
334 |
+
lookup.add(
|
335 |
+
(
|
336 |
+
"Animalia",
|
337 |
+
"Arthropoda",
|
338 |
+
"Insecta",
|
339 |
+
"Hymenoptera",
|
340 |
+
"Apidae",
|
341 |
+
"Bombus",
|
342 |
+
"balteatus",
|
343 |
+
)
|
344 |
+
)
|
345 |
+
|
346 |
+
actual = lookup.children(
|
347 |
+
("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox")
|
348 |
+
)
|
349 |
+
expected = set([("scutulata", 10), ("plesseni", 11)])
|
350 |
+
assert actual == expected
|
351 |
+
|
352 |
+
|
353 |
+
def test_taxonomiclookup_children_of_gorilla():
|
354 |
+
lookup = lib.TaxonomicTree()
|
355 |
+
|
356 |
+
lookup.add(
|
357 |
+
(
|
358 |
+
"Animalia",
|
359 |
+
"Chordata",
|
360 |
+
"Aves",
|
361 |
+
"Accipitriformes",
|
362 |
+
"Accipitridae",
|
363 |
+
"Halieaeetus",
|
364 |
+
"leucocephalus",
|
365 |
+
)
|
366 |
+
)
|
367 |
+
lookup.add(
|
368 |
+
(
|
369 |
+
"Animalia",
|
370 |
+
"Chordata",
|
371 |
+
"Aves",
|
372 |
+
"Strigiformes",
|
373 |
+
"Strigidae",
|
374 |
+
"Ninox",
|
375 |
+
"scutulata",
|
376 |
+
)
|
377 |
+
)
|
378 |
+
lookup.add(
|
379 |
+
(
|
380 |
+
"Animalia",
|
381 |
+
"Chordata",
|
382 |
+
"Aves",
|
383 |
+
"Strigiformes",
|
384 |
+
"Strigidae",
|
385 |
+
"Ninox",
|
386 |
+
"plesseni",
|
387 |
+
)
|
388 |
+
)
|
389 |
+
lookup.add(
|
390 |
+
(
|
391 |
+
"Animalia",
|
392 |
+
"Chordata",
|
393 |
+
"Mammalia",
|
394 |
+
"Primates",
|
395 |
+
"Hominidae",
|
396 |
+
"Gorilla",
|
397 |
+
"gorilla",
|
398 |
+
)
|
399 |
+
)
|
400 |
+
lookup.add(
|
401 |
+
(
|
402 |
+
"Animalia",
|
403 |
+
"Arthropoda",
|
404 |
+
"Insecta",
|
405 |
+
"Hymenoptera",
|
406 |
+
"Apidae",
|
407 |
+
"Bombus",
|
408 |
+
"balteatus",
|
409 |
+
)
|
410 |
+
)
|
411 |
+
|
412 |
+
actual = lookup.children(
|
413 |
+
(
|
414 |
+
"Animalia",
|
415 |
+
"Chordata",
|
416 |
+
"Mammalia",
|
417 |
+
"Primates",
|
418 |
+
"Hominidae",
|
419 |
+
"Gorilla",
|
420 |
+
"gorilla",
|
421 |
+
)
|
422 |
+
)
|
423 |
+
expected = set()
|
424 |
+
assert actual == expected
|
425 |
+
|
426 |
+
|
427 |
+
def test_taxonomictree_descendants_last():
|
428 |
+
lookup = lib.TaxonomicTree()
|
429 |
+
|
430 |
+
lookup.add(("A", "B", "C", "D", "E", "F", "G"))
|
431 |
+
|
432 |
+
actual = list(lookup.descendants(("A", "B", "C", "D", "E", "F", "G")))
|
433 |
+
|
434 |
+
expected = [
|
435 |
+
(("A", "B", "C", "D", "E", "F", "G"), 6),
|
436 |
+
]
|
437 |
+
assert actual == expected
|
438 |
+
|
439 |
+
|
440 |
+
def test_taxonomictree_descendants_entire_tree():
|
441 |
+
lookup = lib.TaxonomicTree()
|
442 |
+
|
443 |
+
lookup.add(("A", "B"))
|
444 |
+
|
445 |
+
actual = list(lookup.descendants())
|
446 |
+
|
447 |
+
expected = [
|
448 |
+
(("A",), 0),
|
449 |
+
(("A", "B"), 1),
|
450 |
+
]
|
451 |
+
assert actual == expected
|
452 |
+
|
453 |
+
|
454 |
+
def test_taxonomictree_descendants_entire_tree_with_prefix():
|
455 |
+
lookup = lib.TaxonomicTree()
|
456 |
+
|
457 |
+
lookup.add(("A", "B"))
|
458 |
+
|
459 |
+
actual = list(lookup.descendants(prefix=("A",)))
|
460 |
+
|
461 |
+
expected = [
|
462 |
+
(("A",), 0),
|
463 |
+
(("A", "B"), 1),
|
464 |
+
]
|
465 |
+
assert actual == expected
|
466 |
+
|
467 |
+
|
468 |
+
def test_taxonomictree_descendants_general():
|
469 |
+
lookup = lib.TaxonomicTree()
|
470 |
+
|
471 |
+
lookup.add(("A", "B", "C", "D", "E", "F", "G"))
|
472 |
+
|
473 |
+
actual = list(lookup.descendants(("A", "B", "C", "D")))
|
474 |
+
|
475 |
+
expected = [
|
476 |
+
(("A", "B", "C", "D"), 3),
|
477 |
+
(("A", "B", "C", "D", "E"), 4),
|
478 |
+
(("A", "B", "C", "D", "E", "F"), 5),
|
479 |
+
(("A", "B", "C", "D", "E", "F", "G"), 6),
|
480 |
+
]
|
481 |
+
assert actual == expected
|
txt_emb.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
|
3 |
+
size 969818240
|
txt_emb_species.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
|
3 |
+
size 65731052
|
txt_emb_species.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
|
3 |
+
size 787435648
|