Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from shap_e.models.query import Query | |
| from shap_e.models.renderer import append_tensor | |
| from shap_e.util.collections import AttrDict | |
class Model(ABC):
    """
    Abstract base class for models that are evaluated on a ``Query``.

    Subclasses implement :meth:`forward`. :meth:`forward_batched` evaluates
    the model by calling the instance itself (``self(...)``), so a concrete
    subclass must also be callable.
    NOTE(review): presumably subclasses also inherit from ``torch.nn.Module``,
    whose ``__call__`` dispatches to ``forward`` -- confirm against subclasses.
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Predict an attribute given position

        Must be overridden by subclasses; this base implementation has no
        body and is marked abstract so that a missing override is reported
        at instantiation time instead of silently returning ``None``.
        """

    def forward_batched(
        self,
        query: Query,
        query_batch_size: int = 4096,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the model on ``query`` in chunks and concatenate the results.

        The query is split along dimension 1 into slices of at most
        ``query_batch_size`` entries, the model is called once per slice, the
        per-slice outputs are accumulated with ``append_tensor`` and each
        accumulated entry is finally joined with ``torch.cat(..., dim=1)``.

        :param query: the query to evaluate; ``query.position`` determines
            how many entries there are (its size along dimension 1).
        :param query_batch_size: maximum number of entries per model call.
        :param params: forwarded unchanged to every model call.
        :param options: forwarded to every model call. It is accessed via
            attribute syntax (``options.cache``), so a non-``None`` value must
            support that (e.g. an ``AttrDict``). If it carries no cache, a
            temporary ``cache`` entry is installed for the duration of this
            call and removed again before returning or raising.
        :return: the per-chunk outputs concatenated along dimension 1.
        :raises ValueError: if ``query_batch_size`` is not positive.
        """
        if not query.position.numel():
            # Avoid torch.cat() of zero tensors.
            return self(query, params=params, options=options)

        if query_batch_size < 1:
            # A zero step makes range() raise an opaque error and a negative
            # step would silently produce an empty result.
            raise ValueError(
                f"query_batch_size must be a positive integer, got {query_batch_size}"
            )

        if options is None:
            # ``options`` defaults to None, but the cache handling below
            # needs attribute access on it.
            options = AttrDict()

        # Only remove the cache afterwards if this call installed it; a cache
        # supplied by the caller is left in place.
        created_cache = options.cache is None
        if created_cache:
            options.cache = AttrDict()

        try:
            results_list = AttrDict()
            for i in range(0, query.position.shape[1], query_batch_size):
                out = self(
                    # ``i=i`` binds the current offset at lambda creation time.
                    query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]),
                    params=params,
                    options=options,
                )
                results_list = results_list.combine(out, append_tensor)
        finally:
            # Clean up even when a chunk fails, so the temporary cache never
            # leaks into the caller's options.
            if created_cache:
                del options["cache"]

        return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1))