Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| # flake8: noqa | |
| import torch | |
| import torch.nn.functional as F | |
def get_r_precision(image_text_code, eps=1e-5):
    """Compute caption R-precision over consecutive groups of 100 samples.

    The input is split in half along dim 1: the first half is treated as the
    image codes and the second half as the text codes. For each complete
    group of 100 consecutive rows, both halves are L2-normalized and the
    cosine similarity between every image code and every text code of the
    group is computed. An image counts as a hit when its most similar text
    code is the one stored at the same row. Trailing rows that do not fill a
    whole group of 100 are ignored.

    Args:
        image_text_code: Tensor expected to be 2-D, laid out as
            (num_samples, 2 * code_dim), with at least 100 rows.
        eps: Small value forwarded to ``F.normalize`` to avoid division by
            zero when normalizing.

    Returns:
        dict: ``{"caption_rprec": value}`` where ``value`` is the mean number
        of hits per group of 100 (so it lies between 0 and 100).

    Raises:
        AssertionError: If fewer than 100 samples are provided.
    """
    group_size = 100
    all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1)
    num_samples = len(all_image_code)
    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O``; without it an empty hit list would surface later as an
    # unrelated ZeroDivisionError. The exception type is kept for callers.
    if num_samples < group_size:
        raise AssertionError(
            "get_r_precision needs at least %d samples, got %d"
            % (group_size, num_samples)
        )

    # The expected top-1 index of row k inside a group is k itself; this is
    # loop-invariant, so build it once.
    expected_indices = torch.arange(group_size, device=all_image_code.device)

    hits_per_group = []
    # The stop bound only yields start offsets of complete groups.
    for start in range(0, num_samples - group_size + 1, group_size):
        stop = start + group_size
        image_code = F.normalize(all_image_code[start:stop], dim=1, eps=eps)
        text_code = F.normalize(all_text_code[start:stop], dim=1, eps=eps)
        cosine_similarities = image_code @ text_code.T
        top1_indices = torch.topk(cosine_similarities, dim=1, k=1)[1][:, 0]
        hits = torch.sum(top1_indices == expected_indices).item()
        hits_per_group.append(hits)

    mean_hits = sum(hits_per_group) * 1.0 / len(hits_per_group)
    return {"caption_rprec": mean_hits}