Commit 2a79ef4
Parent(s): acef01f

Creating captioning pipeline with nucleus sampling

pipeline.py  +18 -13
pipeline.py  CHANGED
@@ -2,21 +2,28 @@ from typing import  Dict, List, Any
 from PIL import Image
 import requests
 import torch
+from blip import blip_decoder
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-from transformers import pipeline, AutoTokenizer
-
 
 class PreTrainedPipeline():
     def __init__(self, path=""):
         # load the optimized model
-
-
-
-        self.
+        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+        self.model = blip_decoder(pretrained=self.model_url, image_size=384, vit='large')
+        self.model.eval()
+        self.model = self.model.to(device)
+
+        image_size = 384
+        self.transform = transforms.Compose([
+            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
 
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
@@ -29,13 +36,11 @@ class PreTrainedPipeline():
             - "label": A string representing what the label/class is. There can be multiple labels.
             - "score": A score between 0 and 1 describing how confident the model is for this label/class.
         """
-
+        image = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
 
-
-
-
-        else:
-            prediction = self.pipeline(inputs)
+        image = self.transform(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
         # postprocess the prediction
-        return
+        return caption
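For local testing, the class above can be exercised directly. This is a minimal sketch, not part of the commit: it assumes the BLIP repository (which provides blip_decoder) is on PYTHONPATH, that the checkpoint URL in __init__ is reachable, and the image URL below is a hypothetical placeholder.

import requests
from PIL import Image

from pipeline import PreTrainedPipeline

pipe = PreTrainedPipeline()

# Any RGB image works; the transform built in __init__ resizes it to 384x384
# and normalizes it before it reaches the BLIP encoder.
img = Image.open(requests.get('https://example.com/demo.jpg', stream=True).raw).convert('RGB')

# __call__ pops the image out of the payload's "inputs" key.
result = pipe({'inputs': img})
print(result)  # BLIP's generate() returns a list of caption strings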

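The "nucleus sampling" in the commit title refers to the sample=True, top_p=0.9 arguments passed to generate(): at each decoding step the next token is drawn only from the smallest set of tokens whose cumulative probability exceeds 0.9, which trades the determinism of beam search for more varied captions. The snippet below is a standalone illustration of that filtering step, not BLIP's internal implementation:

import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Keep the smallest set of tokens whose cumulative probability exceeds
    # top_p; everything else is masked to -inf so softmax gives it zero mass.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = probs.cumsum(dim=-1)
    # Subtracting probs shifts the cutoff so the first token that crosses
    # the threshold is still kept.
    drop = cumulative - probs > top_p
    sorted_logits[drop] = float('-inf')
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, sorted_idx, sorted_logits)
    return filtered

# Toy vocabulary of five tokens: sample one id from the truncated distribution.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])
next_id = torch.multinomial(torch.softmax(top_p_filter(logits), dim=-1), num_samples=1)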