nielsr (HF Staff) committed · verified
Commit 2428a78 · 1 parent: c9b2f7b

Add code snippets, metadata tags


This PR adds sample code snippets (are these OK, or do they look different for the chat version?) and the missing `library_name` and `pipeline_tag` metadata tags.

Files changed (1)
1. README.md +122 -0
README.md CHANGED
@@ -1,6 +1,8 @@
---
language: en
license: mit
library_name: transformers
pipeline_tag: image-text-to-text
---
# Kosmos-2.5-chat

@@ -13,6 +15,126 @@ Kosmos-2.5-chat is a model specifically trained for Visual Question Answering (V
 
[Kosmos-2.5: A Multimodal Literate Model](https://arxiv.org/abs/2309.11419)

## Usage

KOSMOS-2.5 is supported in Transformers >= 4.56. Find the docs [here](https://huggingface.co/docs/transformers/main/en/model_doc/kosmos2_5).
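A quick way to check this requirement (a minimal sketch; `packaging` is a dependency of Transformers, so no extra install should be needed):

```python
# Verify that the installed Transformers version is recent enough for KOSMOS-2.5.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.56.0"), (
    f"transformers {transformers.__version__} is too old for KOSMOS-2.5; "
    "upgrade with `pip install -U transformers`"
)
```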
### Image-to-markdown

```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration, infer_device

repo = "microsoft/kosmos-2.5-chat"
device = f"{infer_device()}:0"
dtype = torch.bfloat16
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, dtype=dtype)
processor = AutoProcessor.from_pretrained(repo)

# sample image
url = "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "<md>"
inputs = processor(text=prompt, images=image, return_tensors="pt")

# the processor also returns the resized height/width; pop them so they are
# not passed to generate() (they are only needed to rescale OCR boxes)
height, width = inputs.pop("height"), inputs.pop("width")

inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text[0])
```
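Since the `<md>` prompt makes the model emit markdown directly, the decoded string can simply be written to disk. A small follow-up sketch (the task-token stripping and the file name are illustrative additions):

```python
# Drop the "<md>" task token if it survives decoding (no-op otherwise) and save.
md = generated_text[0].replace(prompt, "", 1).strip()
with open("receipt_00008.md", "w", encoding="utf-8") as f:
    f.write(md)
```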
### Image-to-OCR

```python
import re
import torch
import requests
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration, infer_device

repo = "microsoft/kosmos-2.5-chat"
device = f"{infer_device()}:0"
dtype = torch.bfloat16
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, dtype=dtype)
processor = AutoProcessor.from_pretrained(repo)

# sample image
url = "https://huggingface.co/ydshieh/kosmos-2.5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)

# bs = 1
prompt = "<ocr>"
inputs = processor(text=prompt, images=image, return_tensors="pt")
# the processor resizes the image; keep the scale factors so predicted
# coordinates can be mapped back to the original resolution
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width

# bs > 1, batch generation
# inputs = processor(text=[prompt, prompt], images=[image, image], return_tensors="pt")
# height, width = inputs.pop("height"), inputs.pop("width")
# raw_width, raw_height = image.size
# scale_height = raw_height / height[0]
# scale_width = raw_width / width[0]

inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

def post_process(y, scale_height, scale_width):
    # turn "<bbox><x_..><y_..><x_..><y_..></bbox>text" spans into
    # "x0,y0,x1,y0,x1,y1,x0,y1,text" lines in original-image pixels
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        box = bboxs[i]
        x0, y0, x1, y1 = box
        if not (x0 >= x1 or y0 >= y1):
            x0 = int(x0 * scale_width)
            y0 = int(y0 * scale_height)
            x1 = int(x1 * scale_width)
            y1 = int(y1 * scale_height)
            info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
    return info

output_text = post_process(generated_text[0], scale_height, scale_width)
print(output_text)

# draw the bounding boxes on the original image
draw = ImageDraw.Draw(image)
lines = output_text.split("\n")
for line in lines:
    line = list(line.split(","))
    if len(line) < 8:
        continue
    line = list(map(int, line[:8]))
    draw.polygon(line, outline="red")
image.save("output.png")
```
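Each line of `output_text` has the form `x0,y0,x1,y0,x1,y1,x0,y1,text`: the four corners of the detected box followed by the recognized text. A sketch of a hypothetical helper that turns this into structured records (the text may itself contain commas, hence the bounded split):

```python
# Hypothetical parser for the post-processed OCR lines above.
def parse_ocr_lines(output_text):
    records = []
    for line in output_text.splitlines():
        parts = line.split(",", 8)  # 8 coordinates, then the text
        if len(parts) < 9:
            continue
        coords = list(map(int, parts[:8]))
        corners = list(zip(coords[0::2], coords[1::2]))  # [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
        records.append({"corners": corners, "text": parts[8]})
    return records

for rec in parse_ocr_lines(output_text):
    print(rec["corners"], rec["text"])
```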
## NOTE:
Since this is a generative model, there is a risk of **hallucination** during generation, and it **cannot** guarantee the accuracy of all results extracted from the images.