Add code snippets, library name and pipeline tag (#18)
Browse files- Add code snippets, library name and pipeline tag (21ab46e9261f3c32615e400e3c3f7e67cd227ddf)
Co-authored-by: Niels Rogge <nielsr@users.noreply.huggingface.co>
    	
        README.md
    CHANGED
    
    | @@ -1,6 +1,8 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            language: en
         | 
| 3 | 
             
            license: mit
         | 
|  | |
|  | |
| 4 | 
             
            ---
         | 
| 5 | 
             
            # Kosmos-2.5
         | 
| 6 |  | 
| @@ -16,10 +18,125 @@ Kosmos-2.5 is a multimodal literate model for machine reading of text-intensive | |
| 16 | 
             
            Since this is a generative model, there is a risk of **hallucination** during the generation process, and it **CAN NOT** guarantee the accuracy of all OCR/Markdown results in the images.
         | 
| 17 |  | 
| 18 | 
             
            ## Inference
         | 
|  | |
|  | |
|  | |
| 19 | 
             
            **Markdown Task:** For usage instructions, please refer to [md.py](md.py).
         | 
| 20 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 | 
             
            **OCR Task:** For usage instructions, please refer to [ocr.py](ocr.py).
         | 
| 22 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 23 | 
             
            ## Citation
         | 
| 24 |  | 
| 25 | 
             
            If you find Kosmos-2.5 useful in your research, please cite the following paper:
         | 
| @@ -36,7 +153,4 @@ If you find Kosmos-2.5 useful in your research, please cite the following paper: | |
| 36 | 
             
            ## License
         | 
| 37 | 
             
            The content of this project itself is licensed under the [MIT](https://github.com/microsoft/unilm/blob/master/kosmos-2.5/LICENSE)
         | 
| 38 |  | 
| 39 | 
            -
            [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
         | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
             
            language: en
         | 
| 3 | 
             
            license: mit
         | 
| 4 | 
            +
            library_name: transformers
         | 
| 5 | 
            +
            pipeline_tag: image-text-to-text
         | 
| 6 | 
             
            ---
         | 
| 7 | 
             
            # Kosmos-2.5
         | 
| 8 |  | 
|  | |
| 18 | 
             
            Since this is a generative model, there is a risk of **hallucination** during the generation process, and it **CAN NOT** guarantee the accuracy of all OCR/Markdown results in the images.
         | 
| 19 |  | 
| 20 | 
             
            ## Inference
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            KOSMOS-2.5 is supported from Transformers >= 4.56. Find the docs [here](https://huggingface.co/docs/transformers/main/en/model_doc/kosmos2_5).
         | 
| 23 | 
            +
             | 
| 24 | 
             
            **Markdown Task:** For usage instructions, please refer to [md.py](md.py).
         | 
| 25 |  | 
| 26 | 
            +
            ```py
         | 
| 27 | 
            +
            import re
         | 
| 28 | 
            +
            import torch
         | 
| 29 | 
            +
            import requests
         | 
| 30 | 
            +
            from PIL import Image, ImageDraw
         | 
| 31 | 
            +
            from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration, infer_device
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            repo = "microsoft/kosmos-2.5"
         | 
| 34 | 
            +
            device = "cuda:0"
         | 
| 35 | 
            +
            dtype = torch.bfloat16
         | 
| 36 | 
            +
            model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, dtype=dtype)
         | 
| 37 | 
            +
            processor = AutoProcessor.from_pretrained(repo)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            # sample image
         | 
| 40 | 
            +
            url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
         | 
| 41 | 
            +
            image = Image.open(requests.get(url, stream=True).raw)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            prompt = "<md>"
         | 
| 44 | 
            +
            inputs = processor(text=prompt, images=image, return_tensors="pt")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            height, width = inputs.pop("height"), inputs.pop("width")
         | 
| 47 | 
            +
            raw_width, raw_height = image.size
         | 
| 48 | 
            +
            scale_height = raw_height / height
         | 
| 49 | 
            +
            scale_width = raw_width / width
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
         | 
| 52 | 
            +
            inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
         | 
| 53 | 
            +
            generated_ids = model.generate(
         | 
| 54 | 
            +
                **inputs,
         | 
| 55 | 
            +
                max_new_tokens=1024,
         | 
| 56 | 
            +
            )
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
         | 
| 59 | 
            +
            print(generated_text[0])
         | 
| 60 | 
            +
            ```
         | 
| 61 | 
            +
             | 
| 62 | 
             
            **OCR Task:** For usage instructions, please refer to [ocr.py](ocr.py).
         | 
| 63 |  | 
| 64 | 
            +
            ```py
         | 
| 65 | 
            +
            import re
         | 
| 66 | 
            +
            import torch
         | 
| 67 | 
            +
            import requests
         | 
| 68 | 
            +
            from PIL import Image, ImageDraw
         | 
| 69 | 
            +
            from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration, infer_device
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            repo = "microsoft/kosmos-2.5"
         | 
| 72 | 
            +
            device = "cuda:0"
         | 
| 73 | 
            +
            dtype = torch.bfloat16
         | 
| 74 | 
            +
            model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, dtype=dtype)
         | 
| 75 | 
            +
            processor = AutoProcessor.from_pretrained(repo)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            # sample image
         | 
| 78 | 
            +
            url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
         | 
| 79 | 
            +
            image = Image.open(requests.get(url, stream=True).raw)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            # bs = 1
         | 
| 82 | 
            +
            prompt = "<ocr>"
         | 
| 83 | 
            +
            inputs = processor(text=prompt, images=image, return_tensors="pt")
         | 
| 84 | 
            +
            height, width = inputs.pop("height"), inputs.pop("width")
         | 
| 85 | 
            +
            raw_width, raw_height = image.size
         | 
| 86 | 
            +
            scale_height = raw_height / height
         | 
| 87 | 
            +
            scale_width = raw_width / width
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            # bs > 1, batch generation
         | 
| 90 | 
            +
            # inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")
         | 
| 91 | 
            +
            # height, width = inputs.pop("height"), inputs.pop("width")
         | 
| 92 | 
            +
            # raw_width, raw_height = image.size
         | 
| 93 | 
            +
            # scale_height = raw_height / height[0]
         | 
| 94 | 
            +
            # scale_width = raw_width / width[0]
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
         | 
| 97 | 
            +
            inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
         | 
| 98 | 
            +
            generated_ids = model.generate(
         | 
| 99 | 
            +
                **inputs,
         | 
| 100 | 
            +
                max_new_tokens=1024,
         | 
| 101 | 
            +
            )
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
         | 
| 104 | 
            +
            def post_process(y, scale_height, scale_width):
         | 
| 105 | 
            +
                y = y.replace(prompt, "")
         | 
| 106 | 
            +
                if "<md>" in prompt:
         | 
| 107 | 
            +
                    return y
         | 
| 108 | 
            +
                pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
         | 
| 109 | 
            +
                bboxs_raw = re.findall(pattern, y)
         | 
| 110 | 
            +
                lines = re.split(pattern, y)[1:]
         | 
| 111 | 
            +
                bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
         | 
| 112 | 
            +
                bboxs = [[int(j) for j in i] for i in bboxs]
         | 
| 113 | 
            +
                info = ""
         | 
| 114 | 
            +
                for i in range(len(lines)):
         | 
| 115 | 
            +
                    box = bboxs[i]
         | 
| 116 | 
            +
                    x0, y0, x1, y1 = box
         | 
| 117 | 
            +
                    if not (x0 >= x1 or y0 >= y1):
         | 
| 118 | 
            +
                        x0 = int(x0 * scale_width)
         | 
| 119 | 
            +
                        y0 = int(y0 * scale_height)
         | 
| 120 | 
            +
                        x1 = int(x1 * scale_width)
         | 
| 121 | 
            +
                        y1 = int(y1 * scale_height)
         | 
| 122 | 
            +
                        info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
         | 
| 123 | 
            +
                return info
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            output_text = post_process(generated_text[0], scale_height, scale_width)
         | 
| 126 | 
            +
            print(output_text)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            draw = ImageDraw.Draw(image)
         | 
| 129 | 
            +
            lines = output_text.split("\n")
         | 
| 130 | 
            +
            for line in lines:
         | 
| 131 | 
            +
                # draw the bounding box
         | 
| 132 | 
            +
                line = list(line.split(","))
         | 
| 133 | 
            +
                if len(line) < 8:
         | 
| 134 | 
            +
                    continue
         | 
| 135 | 
            +
                line = list(map(int, line[:8]))
         | 
| 136 | 
            +
                draw.polygon(line, outline="red")
         | 
| 137 | 
            +
            image.save("output.png")
         | 
| 138 | 
            +
            ```
         | 
| 139 | 
            +
             | 
| 140 | 
             
            ## Citation
         | 
| 141 |  | 
| 142 | 
             
            If you find Kosmos-2.5 useful in your research, please cite the following paper:
         | 
|  | |
| 153 | 
             
            ## License
         | 
| 154 | 
             
            The content of this project itself is licensed under the [MIT](https://github.com/microsoft/unilm/blob/master/kosmos-2.5/LICENSE)
         | 
| 155 |  | 
| 156 | 
            +
            [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
         | 
|  | |
|  | |
|  | 

