Spaces:
Runtime error
Runtime error
| """ | |
| Makes the entire set of text embeddings for all possible names in the tree of life. | |
| Uses the catalog.csv file from TreeOfLife-10M. | |
| """ | |
| import argparse | |
| import csv | |
| import json | |
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from open_clip import create_model, get_tokenizer | |
| from tqdm import tqdm | |
| import lib | |
| from templates import openai_imagenet_template | |
# Model selection, passed to open_clip's create_model / get_tokenizer below.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Taxonomic ranks from most general to most specific; a name with k parts
# is treated as belonging to ranks[k - 1].
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

# Configure the root logger with timestamped, level-tagged records.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)
logger = logging.getLogger()
@torch.no_grad()
def write_txt_features(name_lookup):
    """Compute a template-averaged text embedding for every name in ``name_lookup``.

    Each name is expanded with every template in ``openai_imagenet_template``,
    tokenized and encoded with the module-level ``tokenizer`` and ``model``,
    L2-normalized per template, averaged over templates and re-normalized.
    The results are written into a float32 ``(512, len(name_lookup))`` matrix,
    one column per index. The matrix is saved to ``args.out_path`` every 100
    batches and once more at the end.

    If ``args.out_path`` already exists, it is loaded first. Any batch whose
    columns already contain a non-zero value is skipped, so an interrupted run
    can be resumed.

    The function runs under ``torch.no_grad()``. Without it the encoder output
    carries autograd state, and ``Tensor.numpy()`` refuses tensors that require
    grad.

    NOTE(review): ``name_lookup.values()`` yields ``(name, index)`` pairs, and
    ``lib.batched`` is assumed to regroup them into ``(names, indices)`` tuples
    per batch. Confirm against ``lib``.
    """
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)

    n_templates = len(openai_imagenet_template)
    # --batch-size counts prompts (names x templates). Convert it to a number
    # of names, but never let it reach zero.
    batch_size = max(1, args.batch_size // n_templates)
    # Ceiling division, so the final partial batch is counted by tqdm.
    n_batches = -(-len(name_lookup) // batch_size)

    for batch, (names, indices) in enumerate(
        tqdm(
            lib.batched(name_lookup.values(), batch_size),
            desc="txt feats",
            total=n_batches,
        )
    ):
        # Resume support: a batch with any non-zero column was already written.
        if all_features[:, indices].any():
            logger.info("Skipping batch %s", batch)
            continue

        # Name-major order: all templates for names[0], then names[1], ...
        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Recover the (name, template) grid built above.
        txt_features = torch.reshape(txt_features, (len(names), n_templates, 512))
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, indices] = txt_features.T.cpu().numpy()

        # Periodic checkpoint so progress survives interruption.
        if batch % 100 == 0:
            np.save(args.out_path, all_features)

    np.save(args.out_path, all_features)
def convert_txt_features_to_avgs(name_lookup):
    """Replace each above-species embedding with the mean of its descendants.

    Loads the matrix written by ``write_txt_features`` from ``args.out_path``
    onto ``device``. Ranks are visited from Genus up to Kingdom (Species is
    skipped). For each name at those ranks:

    * The columns of its descendants that have at least 6 ranks (genus- or
      species-level, as returned by ``name_lookup.descendants(prefix=name)``)
      are averaged and L2-normalized.
    * The result overwrites the name's own column.
    * A name with no such descendants gets an all-zero column and is counted
      in a final warning.

    The result is saved as ``<out_path>_avgs<ext>``.

    NOTE(review): genus-level descendants are included and Genus is processed
    first. Higher ranks therefore average over genus columns that already hold
    averages. Confirm this weighting is intended.

    Raises:
        FileNotFoundError: if ``args.out_path`` does not exist yet.
    """
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(
            f"No text features at {args.out_path}; run write_txt_features first."
        )
    # Keep the whole matrix on the accelerator so the per-name means are fast.
    all_features = torch.from_numpy(np.load(args.out_path)).to(device)
    logger.info("Loaded text features from disk to %s.", device)

    # Bucket (name, index) pairs by depth: a name with k parts goes in ranks[k - 1].
    # Names are added to a set, so they are presumably tuples (hashable).
    names_by_rank = [set() for _ in ranks]
    for name, index in tqdm(name_lookup.values()):
        names_by_rank[len(name) - 1].add((name, index))

    zeroed = 0
    for rank_idx, rank in reversed(list(enumerate(ranks))):
        if rank == "Species":
            continue
        for name, index in tqdm(names_by_rank[rank_idx], desc=rank):
            # Column indices of genus- and species-level descendants (6 or 7 parts).
            desc_indices = [
                desc_index
                for desc_name, desc_index in name_lookup.descendants(prefix=name)
                if len(desc_name) >= 6
            ]
            if not desc_indices:
                logger.warning("No species for %s.", " ".join(name))
                all_features[:, index] = 0.0
                zeroed += 1
                continue
            mean = all_features[:, desc_indices].mean(dim=1)
            all_features[:, index] = F.normalize(mean, dim=0)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_avgs{ext}", all_features.cpu().numpy())
    if zeroed:
        logger.warning(
            "Zeroed out %d nodes because they didn't have any genus or species-level labels.",
            zeroed,
        )
def convert_txt_features_to_species_only(name_lookup):
    """Extract species-level embeddings into their own matrix and name list.

    Loads the matrix at ``args.out_path`` and keeps only the columns of names
    with all 7 ranks, in the order returned by ``name_lookup.descendants()``.
    It then writes two files:

    * ``<out_path>_species<ext>``: float32 matrix of shape ``(512, n_species)``
      (512 rows, as allocated by ``write_txt_features``).
    * ``<out_path>_species.json``: the matching names; column j of the matrix
      belongs to entry j.

    Raises:
        FileNotFoundError: if ``args.out_path`` does not exist yet.
    """
    if not os.path.isfile(args.out_path):
        raise FileNotFoundError(
            f"No text features at {args.out_path}; run write_txt_features first."
        )
    all_features = np.load(args.out_path)
    logger.info("Loaded text features from disk.")

    species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
    species_names = [name for name, _ in species]
    # Explicit integer dtype so an empty selection is still a valid index array.
    species_indices = np.asarray([index for _, index in species], dtype=np.intp)
    # One fancy-indexing gather (returns a copy) instead of a per-column Python loop.
    species_features = all_features[:, species_indices].astype(np.float32, copy=False)

    out_path, ext = os.path.splitext(args.out_path)
    np.save(f"{out_path}_species{ext}", species_features)
    with open(f"{out_path}_species.json", "w") as fd:
        json.dump(species_names, fd, indent=2)
def get_name_lookup(catalog_path, cache_path):
    """Load the taxonomic name tree from ``cache_path``, or build it from the catalog.

    If ``cache_path`` exists, it is parsed as JSON and handed to
    ``lib.TaxonomicTree.from_dict``. Otherwise the tree is built from
    ``catalog_path``:

    * Every row of the TreeOfLife-10M catalog.csv contributes its
      kingdom..species columns.
    * Each name is truncated at the first missing rank.
    * The resulting tree is written to ``cache_path`` for the next run.

    Args:
        catalog_path: path to catalog.csv.
        cache_path: JSON cache file, read if present and written otherwise.

    Returns:
        The ``lib.TaxonomicTree`` of all (possibly partial) names.
    """
    if os.path.isfile(cache_path):
        with open(cache_path) as fd:
            return lib.TaxonomicTree.from_dict(json.load(fd))

    lookup = lib.TaxonomicTree()
    # The csv module expects newline="" so quoted fields with newlines parse correctly.
    with open(catalog_path, newline="", encoding="utf-8") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader, desc="catalog"):
            name = [
                row["kingdom"],
                row["phylum"],
                row["class"],
                row["order"],
                row["family"],
                row["genus"],
                row["species"],
            ]
            # Keep only the leading run of filled-in ranks. A missing value may
            # be "" or None (DictReader's default restval for short rows).
            depth = next((i for i, value in enumerate(name) if not value), len(name))
            lookup.add(name[:depth])

    # Write to the same path checked above so the cache is found next run.
    with open(cache_path, "w") as fd:
        json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
    return lookup
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--catalog-path",
        help="Path to the catalog.csv file from TreeOfLife-10M.",
        required=True,
    )
    parser.add_argument("--out-path", help="Path to the output file.", required=True)
    parser.add_argument(
        "--name-cache-path",
        help="Path to the name cache file.",
        default="name_lookup.json",
    )
    parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
    # The stage functions read ``args``, ``model`` and ``tokenizer`` as module
    # globals, so these must stay assigned at module level (not inside a main()).
    args = parser.parse_args()

    name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
    logger.info("Got name lookup.")

    model = create_model(model_str, output_dict=True, require_pretrained=True)
    # open_clip models start in train mode. Switch to eval mode so
    # train-only layers (dropout, stochastic depth, ...) behave deterministically.
    model = model.to(device).eval()
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)

    # Stages: raw per-name embeddings -> rank-level averages -> species-only extract.
    write_txt_features(name_lookup)
    convert_txt_features_to_avgs(name_lookup)
    convert_txt_features_to_species_only(name_lookup)