Ben Feuer commited on
Commit
ec3c973
1 Parent(s): 1e0e251

BioTrove Demo

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
37
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
39
+ components/metadata.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ __pycache__/
README.md CHANGED
@@ -1,14 +1,11 @@
1
  ---
2
  title: BioTrove Demo
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: A new vision-language model for zero-shot species detection
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: BioTrove Demo
3
+ emoji: 🐘
4
+ colorFrom: indigo
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ ---
 
 
 
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import heapq
3
+ import json
4
+ import os
5
+ import logging
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import polars as pl
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from open_clip import create_model, get_tokenizer
13
+ from torchvision import transforms
14
+
15
+ from templates import openai_imagenet_template
16
+ from components.query import get_sample
17
+
18
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
+ logging.basicConfig(level=logging.INFO, format=log_format)
20
+ logger = logging.getLogger()
21
+
22
+ hf_token = os.getenv("HF_TOKEN")
23
+
24
+ # For sample images
25
+ METADATA_PATH = "components/metadata.csv"
26
+ # Read page ID as int and filter out smaller ablation duplicated training split
27
+ metadata_df = pl.read_csv(METADATA_PATH, low_memory = False)
28
+ metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
29
+
30
+ model_str = "hf-hub:penfever/biotrove-o"
31
+ tokenizer_str = "ViT-B-16"
32
+
33
+ txt_emb_npy = "txt_emb_species.npy"
34
+ txt_names_json = "txt_emb_species.json"
35
+
36
+ min_prob = 1e-9
37
+ k = 5
38
+
39
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
+
41
+ preprocess_img = transforms.Compose(
42
+ [
43
+ transforms.ToTensor(),
44
+ transforms.Resize((224, 224), antialias=True),
45
+ transforms.Normalize(
46
+ mean=(0.48145466, 0.4578275, 0.40821073),
47
+ std=(0.26862954, 0.26130258, 0.27577711),
48
+ ),
49
+ ]
50
+ )
51
+
52
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
53
+
54
+ open_domain_examples = [
55
+ ["examples/Ursus-arctos.jpeg", "Species"],
56
+ ["examples/Phoca-vitulina.png", "Species"],
57
+ ["examples/Felis-catus.jpeg", "Genus"],
58
+ ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
59
+ ]
60
+ zero_shot_examples = [
61
+ [
62
+ "examples/Ursus-arctos.jpeg",
63
+ "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
64
+ ],
65
+ ["examples/milk-snake.png", "coral snake\nmilk snake"],
66
+ ["examples/coral-snake.jpeg", "coral snake\nmilk snake"],
67
+ [
68
+ "examples/Carnegiea-gigantea.png",
69
+ "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
70
+ ],
71
+ [
72
+ "examples/Amanita-muscaria.jpeg",
73
+ "Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
74
+ ],
75
+ [
76
+ "examples/Actinostola-abyssorum.png",
77
+ "Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
78
+ ],
79
+ [
80
+ "examples/Sarcoscypha-coccinea.jpeg",
81
+ "scarlet elf cup (coccinea)\nscharlachroter kelchbecherling (austriaca)\ncrimson cup (dudleyi)\nstalked scarlet cup (occidentalis)",
82
+ ],
83
+ [
84
+ "examples/Onoclea-hintonii.jpg",
85
+ "Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
86
+ ],
87
+ [
88
+ "examples/Onoclea-sensibilis.jpg",
89
+ "Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
90
+ ],
91
+ ]
92
+
93
+
94
+ def indexed(lst, indices):
95
+ return [lst[i] for i in indices]
96
+
97
+
98
+ @torch.no_grad()
99
+ def get_txt_features(classnames, templates):
100
+ all_features = []
101
+ for classname in classnames:
102
+ txts = [template(classname) for template in templates]
103
+ txts = tokenizer(txts).to(device)
104
+ txt_features = model.encode_text(txts)
105
+ txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
106
+ txt_features /= txt_features.norm()
107
+ all_features.append(txt_features)
108
+ all_features = torch.stack(all_features, dim=1)
109
+ return all_features
110
+
111
+
112
+ @torch.no_grad()
113
+ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
114
+ classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
115
+ txt_features = get_txt_features(classes, openai_imagenet_template)
116
+
117
+ img = preprocess_img(img).to(device)
118
+ img_features = model.encode_image(img.unsqueeze(0))
119
+ img_features = F.normalize(img_features, dim=-1)
120
+
121
+ logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
122
+ probs = F.softmax(logits, dim=0).to("cpu").tolist()
123
+ return {cls: prob for cls, prob in zip(classes, probs)}
124
+
125
+
126
+ def format_name(taxon, common):
127
+ taxon = " ".join(taxon)
128
+ if not common:
129
+ return taxon
130
+ return f"{taxon} ({common})"
131
+
132
+
133
+ @torch.no_grad()
134
+ def open_domain_classification(img, rank: int, return_all=False):
135
+ """
136
+ Predicts from the entire tree of life.
137
+ If targeting a higher rank than species, then this function predicts among all
138
+ species, then sums up species-level probabilities for the given rank.
139
+ """
140
+
141
+ logger.info(f"Starting open domain classification for rank: {rank}")
142
+ img = preprocess_img(img).to(device)
143
+ img_features = model.encode_image(img.unsqueeze(0))
144
+ img_features = F.normalize(img_features, dim=-1)
145
+
146
+ logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
147
+ probs = F.softmax(logits, dim=0)
148
+
149
+ if rank + 1 == len(ranks):
150
+ topk = probs.topk(k)
151
+ prediction_dict = {
152
+ format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
153
+ }
154
+ logger.info(f"Top K predictions: {prediction_dict}")
155
+ top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
156
+ logger.info(f"Top prediction name: {top_prediction_name}")
157
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
158
+ if return_all:
159
+ return prediction_dict, sample_img, taxon_url
160
+ return prediction_dict
161
+
162
+ output = collections.defaultdict(float)
163
+ for i in torch.nonzero(probs > min_prob).squeeze():
164
+ output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
165
+
166
+ topk_names = heapq.nlargest(k, output, key=output.get)
167
+ prediction_dict = {name: output[name] for name in topk_names}
168
+ logger.info(f"Top K names for output: {topk_names}")
169
+ logger.info(f"Prediction dictionary: {prediction_dict}")
170
+
171
+ top_prediction_name = topk_names[0]
172
+ logger.info(f"Top prediction name: {top_prediction_name}")
173
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
174
+ logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
175
+
176
+ if return_all:
177
+ return prediction_dict, sample_img, taxon_url
178
+ return prediction_dict
179
+
180
+
181
+ def change_output(choice):
182
+ return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ logger.info("Starting.")
187
+ model = create_model(model_str, output_dict=True, require_pretrained=True)
188
+ model = model.to(device)
189
+ logger.info("Created model.")
190
+
191
+ model = torch.compile(model)
192
+ logger.info("Compiled model.")
193
+
194
+ tokenizer = get_tokenizer(tokenizer_str)
195
+
196
+ txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
197
+ with open(txt_names_json) as fd:
198
+ txt_names = json.load(fd)
199
+
200
+ done = txt_emb.any(axis=0).sum().item()
201
+ total = txt_emb.shape[1]
202
+ status_msg = ""
203
+ if done != total:
204
+ status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
205
+
206
+ with gr.Blocks() as app:
207
+
208
+ with gr.Tab("Open-Ended"):
209
+ with gr.Row(variant = "panel", elem_id = "images_panel"):
210
+ with gr.Column():
211
+ img_input = gr.Image(height = 400, sources=["upload"])
212
+
213
+ with gr.Column():
214
+ # display sample image of top predicted taxon
215
+ sample_img = gr.Image(label = "Sample Image of Predicted Taxon",
216
+ height = 400,
217
+ show_download_button = False)
218
+
219
+ taxon_url = gr.HTML(label = "More Information",
220
+ elem_id = "url"
221
+ )
222
+
223
+ with gr.Row():
224
+ with gr.Column():
225
+ rank_dropdown = gr.Dropdown(
226
+ label="Taxonomic Rank",
227
+ info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
228
+ choices=ranks,
229
+ value="Species",
230
+ type="index",
231
+ )
232
+ open_domain_btn = gr.Button("Submit", variant="primary")
233
+ with gr.Column():
234
+ open_domain_output = gr.Label(
235
+ num_top_classes=k,
236
+ label="Prediction",
237
+ show_label=True,
238
+ value=None,
239
+ )
240
+ # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
241
+
242
+ with gr.Row():
243
+ gr.Examples(
244
+ examples=open_domain_examples,
245
+ inputs=[img_input, rank_dropdown],
246
+ cache_examples=True,
247
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
248
+ outputs=[open_domain_output],
249
+ )
250
+ with gr.Tab("Zero-Shot"):
251
+ with gr.Row():
252
+ img_input_zs = gr.Image(height = 400, sources=["upload"])
253
+
254
+ with gr.Row():
255
+ with gr.Column():
256
+ classes_txt = gr.Textbox(
257
+ placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
258
+ lines=3,
259
+ label="Classes",
260
+ show_label=True,
261
+ info="Use taxonomic names where possible; include common names if possible.",
262
+ )
263
+ zero_shot_btn = gr.Button("Submit", variant="primary")
264
+
265
+ with gr.Column():
266
+ zero_shot_output = gr.Label(
267
+ num_top_classes=k, label="Prediction", show_label=True
268
+ )
269
+ # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
270
+
271
+ with gr.Row():
272
+ gr.Examples(
273
+ examples=zero_shot_examples,
274
+ inputs=[img_input_zs, classes_txt],
275
+ cache_examples=True,
276
+ fn=zero_shot_classification,
277
+ outputs=[zero_shot_output],
278
+ )
279
+ rank_dropdown.change(
280
+ fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
281
+ )
282
+
283
+ open_domain_btn.click(
284
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
285
+ inputs=[img_input, rank_dropdown],
286
+ outputs=[open_domain_output, sample_img, taxon_url],
287
+ )
288
+
289
+ zero_shot_btn.click(
290
+ fn=zero_shot_classification,
291
+ inputs=[img_input_zs, classes_txt],
292
+ outputs=zero_shot_output,
293
+ )
294
+
295
+ # Footer to point out to model and data from app page.
296
+ gr.Markdown(
297
+ """
298
+ TODO: Add footer with model and data information.
299
+ """
300
+ )
301
+
302
+ app.queue(max_size=20)
303
+ app.launch(share=True)
components/metadata.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8576f6ca106f35387506369a70df01fb92192a740c3b5da2a12ad8303976aad
3
+ size 233934143
components/metadata_readme.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Bioclip Demo
3
+ emoji: 🐘
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
components/query.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import boto3
3
+ import requests
4
+ import numpy as np
5
+ import polars as pl
6
+ from PIL import Image
7
+ from botocore.config import Config
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # S3 for sample images
13
+ my_config = Config(
14
+ region_name='us-east-1'
15
+ )
16
+ s3_client = boto3.client('s3', config=my_config)
17
+
18
+ # Set basepath for EOL pages for info
19
+ EOL_URL = "https://eol.org/pages/"
20
+ RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
21
+
22
+ def get_sample(df, pred_taxon, rank):
23
+ '''
24
+ Function to retrieve a sample image of the predicted taxon and EOL page link for more info.
25
+
26
+ Parameters:
27
+ -----------
28
+ df : DataFrame
29
+ DataFrame with all sample images listed and their filepaths (in "file_path" column).
30
+ pred_taxon : str
31
+ Predicted taxon of the uploaded image.
32
+ rank : int
33
+ Index of rank in RANKS chosen for prediction.
34
+
35
+ Returns:
36
+ --------
37
+ img : PIL.Image
38
+ Sample image of predicted taxon for display.
39
+ eol_page : str
40
+ URL to EOL page for the taxon (may be a lower rank, e.g., species sample).
41
+ '''
42
+ logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
43
+ try:
44
+ filepath, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
45
+ except Exception as e:
46
+ logger.error(f"Error retrieving sample data: {e}")
47
+ return None, f"We encountered the following error trying to retrieve a sample image: {e}."
48
+ if filepath is None:
49
+ logger.warning(f"No sample image found for taxon: {pred_taxon}")
50
+ return None, f"Sorry, our EOL images do not include {pred_taxon}."
51
+
52
+ # Get sample image of selected individual
53
+ try:
54
+ img_src = s3_client.generate_presigned_url('get_object',
55
+ Params={'Bucket': 'treeoflife-10m-sample-images',
56
+ 'Key': filepath}
57
+ )
58
+ img_resp = requests.get(img_src)
59
+ img = Image.open(io.BytesIO(img_resp.content))
60
+ full_eol_url = EOL_URL + eol_page_id
61
+ if is_exact:
62
+ eol_page = f"<p>Check out the EOL entry for {pred_taxon} to learn more: <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
63
+ else:
64
+ eol_page = f"<p>Check out an example EOL entry within {pred_taxon} to learn more: {full_name} <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
65
+ logger.info(f"Successfully retrieved sample image and EOL page for {pred_taxon}")
66
+ return img, eol_page
67
+ except Exception as e:
68
+ logger.error(f"Error retrieving sample image: {e}")
69
+ return None, f"We encountered the following error trying to retrieve a sample image: {e}."
70
+
71
+ def get_sample_data(df, pred_taxon, rank):
72
+ '''
73
+ Function to randomly select a sample individual of the given taxon and provide associated native location.
74
+
75
+ Parameters:
76
+ -----------
77
+ df : DataFrame
78
+ DataFrame with all sample images listed and their filepaths (in "file_path" column).
79
+ pred_taxon : str
80
+ Predicted taxon of the uploaded image.
81
+ rank : int
82
+ Index of rank in RANKS chosen for prediction.
83
+
84
+ Returns:
85
+ --------
86
+ filepath : str
87
+ Filepath of selected sample image for predicted taxon.
88
+ eol_page_id : str
89
+ EOL page ID associated with predicted taxon for more information.
90
+ full_name : str
91
+ Full taxonomic name of the selected sample.
92
+ is_exact : bool
93
+ Flag indicating if the match is exact (i.e., with empty lower ranks).
94
+ '''
95
+ for idx in range(rank + 1):
96
+ taxon = RANKS[idx]
97
+ target_taxon = pred_taxon.split(" ")[idx]
98
+ df = df.filter(pl.col(taxon) == target_taxon)
99
+
100
+ if df.shape[0] == 0:
101
+ return None, np.nan, "", False
102
+
103
+ # First, try to find entries with empty lower ranks
104
+ exact_df = df
105
+ for lower_rank in RANKS[rank + 1:]:
106
+ exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))
107
+
108
+ if exact_df.shape[0] > 0:
109
+ df_filtered = exact_df.sample()
110
+ full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
111
+ return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, True
112
+
113
+ # If no exact matches, return any entry with the specified rank
114
+ df_filtered = df.sample()
115
+ full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0)) + " " + " ".join(df_filtered.select(RANKS[rank+1:]).row(0))
116
+ return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, False
components/sync_samples_to_s3.bash ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ <<COMMENT
4
+ Usage:
5
+ bash sync_samples_to_s3.bash <BASE_DIR>
6
+
7
+ Dependencies:
8
+ - awscli (https://aws.amazon.com/cli/)
9
+ Credentials to export as environment variables:
10
+ - AWS_ACCESS_KEY_ID
11
+ - AWS_SECRET_ACCESS_KEY
12
+ COMMENT
13
+
14
+ # Check if a valid directory is provided as an argument
15
+ if [ -z "$1" ]; then
16
+ echo "Usage: $0 <BASE_DIR>"
17
+ exit 1
18
+ fi
19
+
20
+ if [ ! -d "$1" ]; then
21
+ echo "Error: $1 is not a valid directory"
22
+ exit 1
23
+ fi
24
+
25
+ BASE_DIR="$1"
26
+ S3_BUCKET="s3://treeoflife-10m-sample-images"
27
+
28
+ # Loop through all directories and sync them to S3
29
+ for dir in $BASE_DIR/*; do
30
+ if [ -d "$dir" ]; then
31
+ dir_name=$(basename "$dir")
32
+ aws s3 sync "$dir" "$S3_BUCKET/$dir_name/"
33
+ fi
34
+ done
embed_texts.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #SBATCH --nodes=1
3
+ #SBATCH --account=PAS2136
4
+ #SBATCH --gpus-per-node=1
5
+ #SBATCH --ntasks-per-node=10
6
+ #SBATCH --job-name=embed-treeoflife
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH --partition=gpu
9
+
10
+ python make_txt_embedding.py \
11
+ --catalog-path /fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/predicted-statistics.csv \
12
+ --out-path text_emb.bin
examples/Actinostola-abyssorum.png ADDED

Git LFS Details

  • SHA256: cc56a3aedc6966da7add6093506ba3fc792b6dd2d3178878968c9c6978a4535a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
examples/Amanita-muscaria.jpeg ADDED

Git LFS Details

  • SHA256: c633755d4d45bc8bf86b4f4b889fc3f7acbeaa0e86cc69fce5f25165e21063eb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
examples/Carnegiea-gigantea.png ADDED

Git LFS Details

  • SHA256: 8e55ff224c0b9421b66c2feaf592f20ba473425b79a5e79abca1c8ca8a001e67
  • Pointer size: 131 Bytes
  • Size of remote file: 419 kB
examples/Felis-catus.jpeg ADDED

Git LFS Details

  • SHA256: 4d68c295156ee782524cc9f4269e3111743f7a12441f49c095b975000512829f
  • Pointer size: 131 Bytes
  • Size of remote file: 650 kB
examples/Onoclea-hintonii.jpg ADDED
examples/Onoclea-sensibilis.jpg ADDED
examples/Phoca-vitulina.png ADDED

Git LFS Details

  • SHA256: c717b35bfc07ebc9b9afd041f62bd1744f69e7e40ed9a6eac3a14f11f1ebc7fc
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
examples/Sarcoscypha-coccinea.jpeg ADDED

Git LFS Details

  • SHA256: 84dfec1fe373d375cd31f129dfd961dfa9d0b400575f9dd9610a08d900fd1cf9
  • Pointer size: 131 Bytes
  • Size of remote file: 409 kB
examples/Ursus-arctos.jpeg ADDED

Git LFS Details

  • SHA256: b1ead956025e2ef9afa71e352326a299881e575bfb42fae65ae2c157196e2e73
  • Pointer size: 131 Bytes
  • Size of remote file: 610 kB
examples/coral-snake.jpeg ADDED

Git LFS Details

  • SHA256: 871066d1d902bbc5ab9fffa38b2a2d5117bf1b5eacc932188b782cdb6a6eed01
  • Pointer size: 130 Bytes
  • Size of remote file: 51.8 kB
examples/milk-snake.png ADDED

Git LFS Details

  • SHA256: 4c5820dfcdaa056903767cc7a3dade6e9e9d24c686fab9d457889879e80fa3ab
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
lib.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mostly a TaxonomicTree class that implements a taxonomy and some helpers for easily
3
+ walking and looking in the tree.
4
+
5
+ A tree is an arrangement of TaxonomicNodes.
6
+
7
+
8
+ """
9
+
10
+
11
+ import itertools
12
+ import json
13
+
14
+
15
+ class TaxonomicNode:
16
+ __slots__ = ("name", "index", "root", "_children")
17
+
18
+ def __init__(self, name, index, root):
19
+ self.name = name
20
+ self.index = index
21
+ self.root = root
22
+ self._children = {}
23
+
24
+ def add(self, name):
25
+ added = 0
26
+ if not name:
27
+ return added
28
+
29
+ first, rest = name[0], name[1:]
30
+ if first not in self._children:
31
+ self._children[first] = TaxonomicNode(first, self.root.size, self.root)
32
+ self.root.size += 1
33
+
34
+ self._children[first].add(rest)
35
+
36
+ def children(self, name):
37
+ if not name:
38
+ return set((child.name, child.index) for child in self._children.values())
39
+
40
+ first, rest = name[0], name[1:]
41
+ if first not in self._children:
42
+ return set()
43
+
44
+ return self._children[first].children(rest)
45
+
46
+ def descendants(self, prefix=None):
47
+ """Iterates over all values in the subtree that match prefix."""
48
+
49
+ if not prefix:
50
+ yield (self.name,), self.index
51
+ for child in self._children.values():
52
+ for name, i in child.descendants():
53
+ yield (self.name, *name), i
54
+ return
55
+
56
+ first, rest = prefix[0], prefix[1:]
57
+ if first not in self._children:
58
+ return
59
+
60
+ for name, i in self._children[first].descendants(rest):
61
+ yield (self.name, *name), i
62
+
63
+ def values(self):
64
+ """Iterates over all (name, i) pairs in the tree."""
65
+ yield (self.name,), self.index
66
+
67
+ for child in self._children.values():
68
+ for name, index in child.values():
69
+ yield (self.name, *name), index
70
+
71
+ @classmethod
72
+ def from_dict(cls, dct, root):
73
+ node = cls(dct["name"], dct["index"], root)
74
+ node._children = {
75
+ child["name"]: cls.from_dict(child, root) for child in dct["children"]
76
+ }
77
+ return node
78
+
79
+
80
+ class TaxonomicTree:
81
+ """
82
+ Efficient structure for finding taxonomic names and their descendants.
83
+ Also returns an integer index i for each possible name.
84
+ """
85
+
86
+ def __init__(self):
87
+ self.kingdoms = {}
88
+ self.size = 0
89
+
90
+ def add(self, name: list[str]):
91
+ if not name:
92
+ return
93
+
94
+ first, rest = name[0], name[1:]
95
+ if first not in self.kingdoms:
96
+ self.kingdoms[first] = TaxonomicNode(first, self.size, self)
97
+ self.size += 1
98
+
99
+ self.kingdoms[first].add(rest)
100
+
101
+ def children(self, name=None):
102
+ if not name:
103
+ return set(
104
+ (kingdom.name, kingdom.index) for kingdom in self.kingdoms.values()
105
+ )
106
+
107
+ first, rest = name[0], name[1:]
108
+ if first not in self.kingdoms:
109
+ return set()
110
+
111
+ return self.kingdoms[first].children(rest)
112
+
113
+ def descendants(self, prefix=None):
114
+ """Iterates over all values in the tree that match prefix."""
115
+ if not prefix:
116
+ # Give them all the subnodes
117
+ for kingdom in self.kingdoms.values():
118
+ yield from kingdom.descendants()
119
+
120
+ return
121
+
122
+ first, rest = prefix[0], prefix[1:]
123
+ if first not in self.kingdoms:
124
+ return
125
+
126
+ yield from self.kingdoms[first].descendants(rest)
127
+
128
+ def values(self):
129
+ """Iterates over all (name, i) pairs in the tree."""
130
+ for kingdom in self.kingdoms.values():
131
+ yield from kingdom.values()
132
+
133
+ def __len__(self):
134
+ return self.size
135
+
136
+ @classmethod
137
+ def from_dict(cls, dct):
138
+ tree = cls()
139
+ tree.kingdoms = {
140
+ kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
141
+ for kingdom in dct["kingdoms"]
142
+ }
143
+ tree.size = dct["size"]
144
+ return tree
145
+
146
+
147
+ class TaxonomicJsonEncoder(json.JSONEncoder):
148
+ def default(self, obj):
149
+ if isinstance(obj, TaxonomicNode):
150
+ return {
151
+ "name": obj.name,
152
+ "index": obj.index,
153
+ "children": list(obj._children.values()),
154
+ }
155
+ elif isinstance(obj, TaxonomicTree):
156
+ return {
157
+ "kingdoms": list(obj.kingdoms.values()),
158
+ "size": obj.size,
159
+ }
160
+ else:
161
+ super().default(self, obj)
162
+
163
+
164
+ def batched(iterable, n):
165
+ # batched('ABCDEFG', 3) --> ABC DEF G
166
+ if n < 1:
167
+ raise ValueError("n must be at least one")
168
+ it = iter(iterable)
169
+ while batch := tuple(itertools.islice(it, n)):
170
+ yield zip(*batch)
make_txt_embedding.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Makes the entire set of text emebeddings for all possible names in the tree of life.
3
+ Uses the catalog.csv file from TreeOfLife-10M.
4
+ """
5
+ import argparse
6
+ import csv
7
+ import json
8
+ import os
9
+ import logging
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from open_clip import create_model, get_tokenizer
16
+ from tqdm import tqdm
17
+
18
+ import lib
19
+ from templates import openai_imagenet_template
20
+
21
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
22
+ logging.basicConfig(level=logging.INFO, format=log_format)
23
+ logger = logging.getLogger()
24
+
25
+ model_str = "hf-hub:imageomics/bioclip"
26
+ tokenizer_str = "ViT-B-16"
27
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
+
29
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
30
+
31
+
32
+ @torch.no_grad()
33
+ def write_txt_features(name_lookup):
34
+ if os.path.isfile(args.out_path):
35
+ all_features = np.load(args.out_path)
36
+ else:
37
+ all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)
38
+
39
+ batch_size = args.batch_size // len(openai_imagenet_template)
40
+ for batch, (names, indices) in enumerate(
41
+ tqdm(
42
+ lib.batched(name_lookup.values(), batch_size),
43
+ desc="txt feats",
44
+ total=len(name_lookup) // batch_size,
45
+ )
46
+ ):
47
+ # Skip if any non-zero elements
48
+ if all_features[:, indices].any():
49
+ logger.info(f"Skipping batch {batch}")
50
+ continue
51
+
52
+ txts = [
53
+ template(name) for name in names for template in openai_imagenet_template
54
+ ]
55
+ txts = tokenizer(txts).to(device)
56
+ txt_features = model.encode_text(txts)
57
+ txt_features = torch.reshape(
58
+ txt_features, (len(names), len(openai_imagenet_template), 512)
59
+ )
60
+ txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
61
+ txt_features /= txt_features.norm(dim=1, keepdim=True)
62
+ all_features[:, indices] = txt_features.T.cpu().numpy()
63
+
64
+ if batch % 100 == 0:
65
+ np.save(args.out_path, all_features)
66
+
67
+ np.save(args.out_path, all_features)
68
+
69
+
70
+ def convert_txt_features_to_avgs(name_lookup):
71
+ assert os.path.isfile(args.out_path)
72
+
73
+ # Put that big boy on the GPU. We're going fast.
74
+ all_features = torch.from_numpy(np.load(args.out_path)).to(device)
75
+ logger.info("Loaded text features from disk to %s.", device)
76
+
77
+ names_by_rank = [set() for rank in ranks]
78
+ for name, index in tqdm(name_lookup.values()):
79
+ i = len(name) - 1
80
+ names_by_rank[i].add((name, index))
81
+
82
+ zeroed = 0
83
+ for i, rank in reversed(list(enumerate(ranks))):
84
+ if rank == "Species":
85
+ continue
86
+ for name, index in tqdm(names_by_rank[i], desc=rank):
87
+ species = tuple(
88
+ zip(
89
+ *(
90
+ (d, i)
91
+ for d, i in name_lookup.descendants(prefix=name)
92
+ if len(d) >= 6
93
+ )
94
+ )
95
+ )
96
+ if not species:
97
+ logger.warning("No species for %s.", " ".join(name))
98
+ all_features[:, index] = 0.0
99
+ zeroed += 1
100
+ continue
101
+
102
+ values, indices = species
103
+ mean = all_features[:, indices].mean(dim=1)
104
+ all_features[:, index] = F.normalize(mean, dim=0)
105
+
106
+ out_path, ext = os.path.splitext(args.out_path)
107
+ np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())
108
+ if zeroed:
109
+ logger.warning(
110
+ "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
111
+ zeroed,
112
+ )
113
+
114
+
115
+ def convert_txt_features_to_species_only(name_lookup):
116
+ assert os.path.isfile(args.out_path)
117
+
118
+ all_features = np.load(args.out_path)
119
+ logger.info("Loaded text features from disk.")
120
+
121
+ species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
122
+ species_features = np.zeros((512, len(species)), dtype=np.float32)
123
+ species_names = [""] * len(species)
124
+
125
+ for new_i, (name, old_i) in enumerate(tqdm(species)):
126
+ species_features[:, new_i] = all_features[:, old_i]
127
+ species_names[new_i] = name
128
+
129
+ out_path, ext = os.path.splitext(args.out_path)
130
+ np.save(f"{out_path}_species{ext}", species_features)
131
+ with open(f"{out_path}_species.json", "w") as fd:
132
+ json.dump(species_names, fd, indent=2)
133
+
134
+
135
+ def get_name_lookup(catalog_path, cache_path):
136
+ if os.path.isfile(cache_path):
137
+ with open(cache_path) as fd:
138
+ lookup = lib.TaxonomicTree.from_dict(json.load(fd))
139
+ return lookup
140
+
141
+ lookup = lib.TaxonomicTree()
142
+
143
+ with open(catalog_path) as fd:
144
+ reader = csv.DictReader(fd)
145
+ for row in tqdm(reader, desc="catalog"):
146
+ name = [
147
+ row["kingdom"],
148
+ row["phylum"],
149
+ row["class"],
150
+ row["order"],
151
+ row["family"],
152
+ row["genus"],
153
+ row["species"],
154
+ ]
155
+ if any(not value for value in name):
156
+ name = name[: name.index("")]
157
+ lookup.add(name)
158
+
159
+ with open(args.name_cache_path, "w") as fd:
160
+ json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
161
+
162
+ return lookup
163
+
164
+
165
+ if __name__ == "__main__":
166
+ parser = argparse.ArgumentParser()
167
+ parser.add_argument(
168
+ "--catalog-path",
169
+ help="Path to the catalog.csv file from TreeOfLife-10M.",
170
+ required=True,
171
+ )
172
+ parser.add_argument("--out-path", help="Path to the output file.", required=True)
173
+ parser.add_argument(
174
+ "--name-cache-path",
175
+ help="Path to the name cache file.",
176
+ default="name_lookup.json",
177
+ )
178
+ parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
179
+ args = parser.parse_args()
180
+
181
+ name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
182
+ logger.info("Got name lookup.")
183
+
184
+ model = create_model(model_str, output_dict=True, require_pretrained=True)
185
+ model = model.to(device)
186
+ logger.info("Created model.")
187
+ model = torch.compile(model)
188
+ logger.info("Compiled model.")
189
+
190
+ tokenizer = get_tokenizer(tokenizer_str)
191
+ write_txt_features(name_lookup)
192
+ convert_txt_features_to_avgs(name_lookup)
193
+ convert_txt_features_to_species_only(name_lookup)
name_lookup.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20d731d9d901f1c17927187bc87e4a2513279845a1a6ba5982dbf779f2ac1434
3
+ size 26462858
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ open_clip_torch
2
+ torchvision
3
+ torch
4
+ gradio
5
+ polars
6
+ pillow
7
+ boto3
templates.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai_imagenet_template = [
2
+ lambda c: f"a bad photo of a {c}.",
3
+ lambda c: f"a photo of many {c}.",
4
+ lambda c: f"a sculpture of a {c}.",
5
+ lambda c: f"a photo of the hard to see {c}.",
6
+ lambda c: f"a low resolution photo of the {c}.",
7
+ lambda c: f"a rendering of a {c}.",
8
+ lambda c: f"graffiti of a {c}.",
9
+ lambda c: f"a bad photo of the {c}.",
10
+ lambda c: f"a cropped photo of the {c}.",
11
+ lambda c: f"a tattoo of a {c}.",
12
+ lambda c: f"the embroidered {c}.",
13
+ lambda c: f"a photo of a hard to see {c}.",
14
+ lambda c: f"a bright photo of a {c}.",
15
+ lambda c: f"a photo of a clean {c}.",
16
+ lambda c: f"a photo of a dirty {c}.",
17
+ lambda c: f"a dark photo of the {c}.",
18
+ lambda c: f"a drawing of a {c}.",
19
+ lambda c: f"a photo of my {c}.",
20
+ lambda c: f"the plastic {c}.",
21
+ lambda c: f"a photo of the cool {c}.",
22
+ lambda c: f"a close-up photo of a {c}.",
23
+ lambda c: f"a black and white photo of the {c}.",
24
+ lambda c: f"a painting of the {c}.",
25
+ lambda c: f"a painting of a {c}.",
26
+ lambda c: f"a pixelated photo of the {c}.",
27
+ lambda c: f"a sculpture of the {c}.",
28
+ lambda c: f"a bright photo of the {c}.",
29
+ lambda c: f"a cropped photo of a {c}.",
30
+ lambda c: f"a plastic {c}.",
31
+ lambda c: f"a photo of the dirty {c}.",
32
+ lambda c: f"a jpeg corrupted photo of a {c}.",
33
+ lambda c: f"a blurry photo of the {c}.",
34
+ lambda c: f"a photo of the {c}.",
35
+ lambda c: f"a good photo of the {c}.",
36
+ lambda c: f"a rendering of the {c}.",
37
+ lambda c: f"a {c} in a video game.",
38
+ lambda c: f"a photo of one {c}.",
39
+ lambda c: f"a doodle of a {c}.",
40
+ lambda c: f"a close-up photo of the {c}.",
41
+ lambda c: f"a photo of a {c}.",
42
+ lambda c: f"the origami {c}.",
43
+ lambda c: f"the {c} in a video game.",
44
+ lambda c: f"a sketch of a {c}.",
45
+ lambda c: f"a doodle of the {c}.",
46
+ lambda c: f"a origami {c}.",
47
+ lambda c: f"a low resolution photo of a {c}.",
48
+ lambda c: f"the toy {c}.",
49
+ lambda c: f"a rendition of the {c}.",
50
+ lambda c: f"a photo of the clean {c}.",
51
+ lambda c: f"a photo of a large {c}.",
52
+ lambda c: f"a rendition of a {c}.",
53
+ lambda c: f"a photo of a nice {c}.",
54
+ lambda c: f"a photo of a weird {c}.",
55
+ lambda c: f"a blurry photo of a {c}.",
56
+ lambda c: f"a cartoon {c}.",
57
+ lambda c: f"art of a {c}.",
58
+ lambda c: f"a sketch of the {c}.",
59
+ lambda c: f"a embroidered {c}.",
60
+ lambda c: f"a pixelated photo of a {c}.",
61
+ lambda c: f"itap of the {c}.",
62
+ lambda c: f"a jpeg corrupted photo of the {c}.",
63
+ lambda c: f"a good photo of a {c}.",
64
+ lambda c: f"a plushie {c}.",
65
+ lambda c: f"a photo of the nice {c}.",
66
+ lambda c: f"a photo of the small {c}.",
67
+ lambda c: f"a photo of the weird {c}.",
68
+ lambda c: f"the cartoon {c}.",
69
+ lambda c: f"art of the {c}.",
70
+ lambda c: f"a drawing of the {c}.",
71
+ lambda c: f"a photo of the large {c}.",
72
+ lambda c: f"a black and white photo of a {c}.",
73
+ lambda c: f"the plushie {c}.",
74
+ lambda c: f"a dark photo of a {c}.",
75
+ lambda c: f"itap of a {c}.",
76
+ lambda c: f"graffiti of the {c}.",
77
+ lambda c: f"a toy {c}.",
78
+ lambda c: f"itap of my {c}.",
79
+ lambda c: f"a photo of a cool {c}.",
80
+ lambda c: f"a photo of a small {c}.",
81
+ lambda c: f"a tattoo of the {c}.",
82
+ ]
test_lib.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib
2
+
3
+
4
+ def test_taxonomiclookup_empty():
5
+ lookup = lib.TaxonomicTree()
6
+ assert lookup.size == 0
7
+
8
+
9
+ def test_taxonomiclookup_kingdom_size():
10
+ lookup = lib.TaxonomicTree()
11
+
12
+ lookup.add(("Animalia",))
13
+
14
+ assert lookup.size == 1
15
+
16
+
17
+ def test_taxonomiclookup_genus_size():
18
+ lookup = lib.TaxonomicTree()
19
+
20
+ lookup.add(
21
+ (
22
+ "Animalia",
23
+ "Chordata",
24
+ "Aves",
25
+ "Accipitriformes",
26
+ "Accipitridae",
27
+ "Halieaeetus",
28
+ )
29
+ )
30
+
31
+ assert lookup.size == 6
32
+
33
+
34
+ def test_taxonomictree_kingdom_children():
35
+ lookup = lib.TaxonomicTree()
36
+
37
+ lookup.add(
38
+ (
39
+ "Animalia",
40
+ "Chordata",
41
+ "Aves",
42
+ "Accipitriformes",
43
+ "Accipitridae",
44
+ "Halieaeetus",
45
+ )
46
+ )
47
+
48
+ expected = set([("Animalia", 0)])
49
+ actual = lookup.children()
50
+ assert actual == expected
51
+
52
+
53
+ def test_taxonomiclookup_children_of_animal_only_birds():
54
+ lookup = lib.TaxonomicTree()
55
+
56
+ lookup.add(
57
+ (
58
+ "Animalia",
59
+ "Chordata",
60
+ "Aves",
61
+ "Accipitriformes",
62
+ "Accipitridae",
63
+ "Halieaeetus",
64
+ "leucocephalus",
65
+ )
66
+ )
67
+ lookup.add(
68
+ (
69
+ "Animalia",
70
+ "Chordata",
71
+ "Aves",
72
+ "Strigiformes",
73
+ "Strigidae",
74
+ "Ninox",
75
+ "scutulata",
76
+ )
77
+ )
78
+ lookup.add(
79
+ (
80
+ "Animalia",
81
+ "Chordata",
82
+ "Aves",
83
+ "Strigiformes",
84
+ "Strigidae",
85
+ "Ninox",
86
+ "plesseni",
87
+ )
88
+ )
89
+
90
+ actual = lookup.children(("Animalia",))
91
+ expected = set([("Chordata", 1)])
92
+ assert actual == expected
93
+
94
+
95
+ def test_taxonomiclookup_children_of_animal():
96
+ lookup = lib.TaxonomicTree()
97
+
98
+ lookup.add(
99
+ (
100
+ "Animalia",
101
+ "Chordata",
102
+ "Aves",
103
+ "Accipitriformes",
104
+ "Accipitridae",
105
+ "Halieaeetus",
106
+ "leucocephalus",
107
+ )
108
+ )
109
+ lookup.add(
110
+ (
111
+ "Animalia",
112
+ "Chordata",
113
+ "Aves",
114
+ "Strigiformes",
115
+ "Strigidae",
116
+ "Ninox",
117
+ "scutulata",
118
+ )
119
+ )
120
+ lookup.add(
121
+ (
122
+ "Animalia",
123
+ "Chordata",
124
+ "Aves",
125
+ "Strigiformes",
126
+ "Strigidae",
127
+ "Ninox",
128
+ "plesseni",
129
+ )
130
+ )
131
+ lookup.add(
132
+ (
133
+ "Animalia",
134
+ "Chordata",
135
+ "Mammalia",
136
+ "Primates",
137
+ "Hominidae",
138
+ "Gorilla",
139
+ "gorilla",
140
+ )
141
+ )
142
+ lookup.add(
143
+ (
144
+ "Animalia",
145
+ "Arthropoda",
146
+ "Insecta",
147
+ "Hymenoptera",
148
+ "Apidae",
149
+ "Bombus",
150
+ "balteatus",
151
+ )
152
+ )
153
+
154
+ actual = lookup.children(("Animalia",))
155
+ expected = set([("Chordata", 1), ("Arthropoda", 17)])
156
+ assert actual == expected
157
+
158
+
159
+ def test_taxonomiclookup_children_of_chordata():
160
+ lookup = lib.TaxonomicTree()
161
+
162
+ lookup.add(
163
+ (
164
+ "Animalia",
165
+ "Chordata",
166
+ "Aves",
167
+ "Accipitriformes",
168
+ "Accipitridae",
169
+ "Halieaeetus",
170
+ "leucocephalus",
171
+ )
172
+ )
173
+ lookup.add(
174
+ (
175
+ "Animalia",
176
+ "Chordata",
177
+ "Aves",
178
+ "Strigiformes",
179
+ "Strigidae",
180
+ "Ninox",
181
+ "scutulata",
182
+ )
183
+ )
184
+ lookup.add(
185
+ (
186
+ "Animalia",
187
+ "Chordata",
188
+ "Aves",
189
+ "Strigiformes",
190
+ "Strigidae",
191
+ "Ninox",
192
+ "plesseni",
193
+ )
194
+ )
195
+ lookup.add(
196
+ (
197
+ "Animalia",
198
+ "Chordata",
199
+ "Mammalia",
200
+ "Primates",
201
+ "Hominidae",
202
+ "Gorilla",
203
+ "gorilla",
204
+ )
205
+ )
206
+ lookup.add(
207
+ (
208
+ "Animalia",
209
+ "Arthropoda",
210
+ "Insecta",
211
+ "Hymenoptera",
212
+ "Apidae",
213
+ "Bombus",
214
+ "balteatus",
215
+ )
216
+ )
217
+
218
+ actual = lookup.children(("Animalia", "Chordata"))
219
+ expected = set([("Aves", 2), ("Mammalia", 12)])
220
+ assert actual == expected
221
+
222
+
223
+ def test_taxonomiclookup_children_of_strigiformes():
224
+ lookup = lib.TaxonomicTree()
225
+
226
+ lookup.add(
227
+ (
228
+ "Animalia",
229
+ "Chordata",
230
+ "Aves",
231
+ "Accipitriformes",
232
+ "Accipitridae",
233
+ "Halieaeetus",
234
+ "leucocephalus",
235
+ )
236
+ )
237
+ lookup.add(
238
+ (
239
+ "Animalia",
240
+ "Chordata",
241
+ "Aves",
242
+ "Strigiformes",
243
+ "Strigidae",
244
+ "Ninox",
245
+ "scutulata",
246
+ )
247
+ )
248
+ lookup.add(
249
+ (
250
+ "Animalia",
251
+ "Chordata",
252
+ "Aves",
253
+ "Strigiformes",
254
+ "Strigidae",
255
+ "Ninox",
256
+ "plesseni",
257
+ )
258
+ )
259
+ lookup.add(
260
+ (
261
+ "Animalia",
262
+ "Chordata",
263
+ "Mammalia",
264
+ "Primates",
265
+ "Hominidae",
266
+ "Gorilla",
267
+ "gorilla",
268
+ )
269
+ )
270
+ lookup.add(
271
+ (
272
+ "Animalia",
273
+ "Arthropoda",
274
+ "Insecta",
275
+ "Hymenoptera",
276
+ "Apidae",
277
+ "Bombus",
278
+ "balteatus",
279
+ )
280
+ )
281
+
282
+ actual = lookup.children(("Animalia", "Chordata", "Aves", "Strigiformes"))
283
+ expected = set([("Strigidae", 8)])
284
+ assert actual == expected
285
+
286
+
287
+ def test_taxonomiclookup_children_of_ninox():
288
+ lookup = lib.TaxonomicTree()
289
+
290
+ lookup.add(
291
+ (
292
+ "Animalia",
293
+ "Chordata",
294
+ "Aves",
295
+ "Accipitriformes",
296
+ "Accipitridae",
297
+ "Halieaeetus",
298
+ "leucocephalus",
299
+ )
300
+ )
301
+ lookup.add(
302
+ (
303
+ "Animalia",
304
+ "Chordata",
305
+ "Aves",
306
+ "Strigiformes",
307
+ "Strigidae",
308
+ "Ninox",
309
+ "scutulata",
310
+ )
311
+ )
312
+ lookup.add(
313
+ (
314
+ "Animalia",
315
+ "Chordata",
316
+ "Aves",
317
+ "Strigiformes",
318
+ "Strigidae",
319
+ "Ninox",
320
+ "plesseni",
321
+ )
322
+ )
323
+ lookup.add(
324
+ (
325
+ "Animalia",
326
+ "Chordata",
327
+ "Mammalia",
328
+ "Primates",
329
+ "Hominidae",
330
+ "Gorilla",
331
+ "gorilla",
332
+ )
333
+ )
334
+ lookup.add(
335
+ (
336
+ "Animalia",
337
+ "Arthropoda",
338
+ "Insecta",
339
+ "Hymenoptera",
340
+ "Apidae",
341
+ "Bombus",
342
+ "balteatus",
343
+ )
344
+ )
345
+
346
+ actual = lookup.children(
347
+ ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox")
348
+ )
349
+ expected = set([("scutulata", 10), ("plesseni", 11)])
350
+ assert actual == expected
351
+
352
+
353
+ def test_taxonomiclookup_children_of_gorilla():
354
+ lookup = lib.TaxonomicTree()
355
+
356
+ lookup.add(
357
+ (
358
+ "Animalia",
359
+ "Chordata",
360
+ "Aves",
361
+ "Accipitriformes",
362
+ "Accipitridae",
363
+ "Halieaeetus",
364
+ "leucocephalus",
365
+ )
366
+ )
367
+ lookup.add(
368
+ (
369
+ "Animalia",
370
+ "Chordata",
371
+ "Aves",
372
+ "Strigiformes",
373
+ "Strigidae",
374
+ "Ninox",
375
+ "scutulata",
376
+ )
377
+ )
378
+ lookup.add(
379
+ (
380
+ "Animalia",
381
+ "Chordata",
382
+ "Aves",
383
+ "Strigiformes",
384
+ "Strigidae",
385
+ "Ninox",
386
+ "plesseni",
387
+ )
388
+ )
389
+ lookup.add(
390
+ (
391
+ "Animalia",
392
+ "Chordata",
393
+ "Mammalia",
394
+ "Primates",
395
+ "Hominidae",
396
+ "Gorilla",
397
+ "gorilla",
398
+ )
399
+ )
400
+ lookup.add(
401
+ (
402
+ "Animalia",
403
+ "Arthropoda",
404
+ "Insecta",
405
+ "Hymenoptera",
406
+ "Apidae",
407
+ "Bombus",
408
+ "balteatus",
409
+ )
410
+ )
411
+
412
+ actual = lookup.children(
413
+ (
414
+ "Animalia",
415
+ "Chordata",
416
+ "Mammalia",
417
+ "Primates",
418
+ "Hominidae",
419
+ "Gorilla",
420
+ "gorilla",
421
+ )
422
+ )
423
+ expected = set()
424
+ assert actual == expected
425
+
426
+
427
+ def test_taxonomictree_descendants_last():
428
+ lookup = lib.TaxonomicTree()
429
+
430
+ lookup.add(("A", "B", "C", "D", "E", "F", "G"))
431
+
432
+ actual = list(lookup.descendants(("A", "B", "C", "D", "E", "F", "G")))
433
+
434
+ expected = [
435
+ (("A", "B", "C", "D", "E", "F", "G"), 6),
436
+ ]
437
+ assert actual == expected
438
+
439
+
440
+ def test_taxonomictree_descendants_entire_tree():
441
+ lookup = lib.TaxonomicTree()
442
+
443
+ lookup.add(("A", "B"))
444
+
445
+ actual = list(lookup.descendants())
446
+
447
+ expected = [
448
+ (("A",), 0),
449
+ (("A", "B"), 1),
450
+ ]
451
+ assert actual == expected
452
+
453
+
454
+ def test_taxonomictree_descendants_entire_tree_with_prefix():
455
+ lookup = lib.TaxonomicTree()
456
+
457
+ lookup.add(("A", "B"))
458
+
459
+ actual = list(lookup.descendants(prefix=("A",)))
460
+
461
+ expected = [
462
+ (("A",), 0),
463
+ (("A", "B"), 1),
464
+ ]
465
+ assert actual == expected
466
+
467
+
468
+ def test_taxonomictree_descendants_general():
469
+ lookup = lib.TaxonomicTree()
470
+
471
+ lookup.add(("A", "B", "C", "D", "E", "F", "G"))
472
+
473
+ actual = list(lookup.descendants(("A", "B", "C", "D")))
474
+
475
+ expected = [
476
+ (("A", "B", "C", "D"), 3),
477
+ (("A", "B", "C", "D", "E"), 4),
478
+ (("A", "B", "C", "D", "E", "F"), 5),
479
+ (("A", "B", "C", "D", "E", "F", "G"), 6),
480
+ ]
481
+ assert actual == expected
txt_emb.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
3
+ size 969818240
txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
3
+ size 65731052
txt_emb_species.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
3
+ size 787435648