Spaces:
Sleeping
Sleeping
| import torch | |
def _resolve_device(device):
    """Map the requested device onto one that is usable on this machine."""
    cuda_ok = torch.cuda.is_available()
    if device is None:
        return torch.device('cuda' if cuda_ok else 'cpu')
    requested = torch.device(device)
    # Historical behaviour: the 'cpu' default has always been upgraded to
    # CUDA when a GPU is present, so existing callers keep getting that.
    if requested.type == 'cpu' and cuda_ok:
        return torch.device('cuda')
    # A CUDA request on a machine without CUDA degrades to CPU, as before.
    if requested.type == 'cuda' and not cuda_ok:
        return torch.device('cpu')
    return requested


def extract_text_feature(prompt, model, processor, device='cpu'):
    """Extract L2-normalised text features for a prompt.

    Args:
        prompt: text query passed to ``processor(text=prompt)``.
        model: object exposing ``owlvit.text_model`` and
            ``owlvit.text_projection`` (an OwlViT model).
        processor: callable whose result contains ``'input_ids'``
            (an OwlViT processor).
        device (str, optional): device the token ids are moved to.
            ``'cpu'`` (the default) and ``None`` mean "no preference": CUDA
            is used when available, otherwise the CPU. A CUDA request falls
            back to the CPU when CUDA is unavailable. Any other explicit
            device (e.g. ``'cuda:1'``) is honoured as given.
            NOTE(review): the model is not moved here; it is assumed to
            already live on the resolved device.

    Returns:
        tuple: ``(input_ids, query_embeds)`` where ``input_ids`` is the
        token-id tensor on the resolved device and ``query_embeds`` is the
        projected text embedding scaled to (approximately) unit L2 norm
        along the last dimension.
    """
    target = _resolve_device(device)
    with torch.no_grad():
        input_ids = torch.as_tensor(
            processor(text=prompt)['input_ids']).to(target)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Index 1 is presumably the pooled text representation - confirm
        # against the text model's output layout.
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # The small epsilon keeps an all-zero embedding from dividing by zero.
        query_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6)
    return input_ids, query_embeds
def prompt2vec(prompt: str, model, processor):
    """Encode a text prompt and hand the results back as NumPy arrays.

    Args:
        prompt (str): Text to be tokenized and embedded
        model: forwarded unchanged to ``extract_text_feature``
        processor: forwarded unchanged to ``extract_text_feature``
    Returns:
        tuple: ``(input_ids, xq)`` - the token ids and the query vector
        produced by ``extract_text_feature``, each detached, moved to the
        CPU and converted to a NumPy array
    """
    token_ids, embedding = extract_text_feature(prompt, model, processor)
    # Detach and move to host memory before handing over to NumPy.
    return tuple(
        tensor.detach().cpu().numpy() for tensor in (token_ids, embedding)
    )
def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier and return its updated weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and
            ``get_weights()``
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by user, one per
            row of ``X``
        iters (int, optional): iterations of updates to be run
    Returns:
        Whatever ``clf.get_weights()`` returns after training.
    Raises:
        ValueError: if ``X`` and ``y`` do not have the same length.
    """
    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let mismatched data through.
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}")
    # train the classifier
    clf.fit(X, y, iters=iters)
    # extract new vector
    return clf.get_weights()
class Classifier:
    """Multi-class zero-shot classifier tuned from user feedback.

    The classifier acts as a proxy for the user's reaction to the probed
    images. Its weights start as the query vectors generated from the
    prompts and are refined with the user's labels, so the tuned weights can
    replace the original query vectors and give a more satisfying retrieval
    result - much like a recommendation system that becomes more precise as
    it accumulates user activity.

    For N queries the classes ``0 .. N-1`` correspond to the queries. Any
    other label (``N`` for the negative class, or anything else out of
    range) marks the sample as a negative for every query.
    """

    def __init__(self, xq):
        """Initialise the weights from the query vectors.

        Args:
            xq: 2-D array-like (nested list, numpy.ndarray or torch.Tensor)
                with one row per query.
        Raises:
            ValueError: if ``xq`` is not 2-D.
        """
        # Always work on a private float32 copy so training never writes
        # back into the caller's buffer.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        if init_weight.dim() != 2:
            raise ValueError(
                "xq must be 2-D (num_queries, dims), got shape "
                f"{tuple(init_weight.shape)}")
        self.num_class, dims = init_weight.shape
        # note that the bias is ignored, as we only focus on the inner product result
        self.model = torch.nn.Linear(dims, self.num_class, bias=False)
        # use the initial query vectors as the layer's weights
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer (created after the weight is replaced so
        # the optimizer tracks the right parameter)
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X, y, iters: int = 5):
        """Run ``iters`` SGD steps on the labelled vectors.

        Args:
            X: array-like of input vectors, one row per sample. The
                caller's data is not modified.
            y: array-like of labels, one per row of ``X``; values are
                truncated to integers. Labels outside ``[0, num_class)``
                are treated as negatives for every class.
            iters: number of update steps.
        """
        feats = torch.as_tensor(X, dtype=torch.float32).detach()
        labels = torch.as_tensor(y).long()
        if feats.numel() == 0:
            # Nothing to learn from.
            return
        # Out-of-place normalisation: leaves the caller's array untouched
        # and maps an all-zero row to zeros instead of NaN.
        feats = torch.nn.functional.normalize(feats, p=2, dim=-1)

        # Build multi-label targets; out-of-range labels become all-zero
        # rows rather than wrapping around onto a real class.
        in_range = (labels >= 0) & (labels < self.num_class)
        safe_labels = torch.where(in_range, labels, torch.zeros_like(labels))
        targets = torch.nn.functional.one_hot(
            safe_labels, num_classes=self.num_class).float()
        targets[~in_range] = 0

        for _ in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # Normalize the weight before inference.
            # This constrains the gradient; otherwise the query vector
            # would explode.
            with torch.no_grad():
                self.model.weight.copy_(torch.nn.functional.normalize(
                    self.model.weight, p=2, dim=-1))
            # forward pass
            out = self.model(feats)
            # compute loss
            loss = self.loss(out, targets)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()

    def get_weights(self):
        """Return the current weights as an independent NumPy array."""
        # Copy so later training steps cannot change an array that was
        # already handed out.
        return self.model.weight.detach().cpu().numpy().copy()
class SplitLayer(torch.nn.Module):
    """Module that breaks its input apart along the last dimension."""

    def forward(self, x):
        """Return the size-1 chunks of ``x`` taken along its last dimension."""
        chunk_size = 1
        return x.split(chunk_size, dim=-1)