Mihaiii committed (verified)
Commit ff4e513 · 1 Parent(s): e9acad5

Upload configuration_ovis.py

Files changed (1):
configuration_ovis.py · ADDED · +204 -0
@@ -0,0 +1,204 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Union, Optional

from transformers import PretrainedConfig, AutoConfig, AutoModel
from .configuration_aimv2 import AIMv2Config
from .modeling_aimv2 import AIMv2Model

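# Sentinel values: none can collide with real vocabulary ids (which are >= 0).
# IGNORE_ID is the conventional cross-entropy masking index; the image-related
# IDs serve as placeholders in the multimodal token stream.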
IGNORE_ID = -100
IMAGE_TOKEN_ID = -200
IMAGE_TOKEN = "<image>"
IMAGE_ATOM_ID = -300
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]

AutoConfig.register("aimv2", AIMv2Config)
AutoModel.register(AIMv2Config, AIMv2Model)

# ----------------------------------------------------------------------
# Visual Tokenizer Configuration
# ----------------------------------------------------------------------
class BaseVisualTokenizerConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=16384,
        tokenize_function="softmax",
        tau=1.0,
        depths=None,
        drop_cls_token=False,
        backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
        hidden_stride: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        if isinstance(depths, str):
            # depths may be given as a '|'-separated string, e.g. "3|4"
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        self.backbone_kwargs = {}
        self.drop_cls_token = drop_cls_token
        if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expected `backbone_config` to be an instance of PretrainedConfig or dict, but got {type(backbone_config)}"
            if not isinstance(backbone_config, PretrainedConfig):
                model_type = backbone_config.pop('model_type')
                backbone_config = AutoConfig.for_model(model_type, **backbone_config)
        self.backbone_config = backbone_config
        self.hidden_stride = hidden_stride


class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
    model_type = "aimv2_visual_tokenizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.drop_cls_token:
            # dropping a [CLS] token is not applicable to the AIMv2 backbone, so force it off
            self.drop_cls_token = False
        if self.depths:
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)


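# Usage sketch (illustrative placeholder values): thanks to the registration
# above, a plain dict round-trips through AutoConfig, e.g.
#     cfg = AutoConfig.for_model("aimv2_visual_tokenizer", vocab_size=65536, depths="3")

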
# ----------------------------------------------------------------------
# Ovis Configuration
# ----------------------------------------------------------------------
class OvisConfig(PretrainedConfig):
    model_type = "ovis"

    def __init__(
        self,
        llm_config: Optional[Union[PretrainedConfig, dict]] = None,
        visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
        multimodal_max_length=8192,
        hidden_size=None,
        conversation_formatter_class=None,
        llm_attn_implementation=None,
        disable_tie_weight=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        if llm_config is not None:
            assert isinstance(llm_config, (PretrainedConfig, dict)), \
                f"expected `llm_config` to be an instance of PretrainedConfig or dict, but got {type(llm_config)}"
            if not isinstance(llm_config, PretrainedConfig):
                model_type = llm_config.pop('model_type')
                llm_config = AutoConfig.for_model(model_type, **llm_config)
        self.llm_config = llm_config
        if visual_tokenizer_config is not None:
            assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
                f"expected `visual_tokenizer_config` to be an instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)}"
            if not isinstance(visual_tokenizer_config, PretrainedConfig):
                model_type = visual_tokenizer_config.pop('model_type')
                visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
        self.visual_tokenizer_config = visual_tokenizer_config
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class
        self.llm_attn_implementation = llm_attn_implementation
        self.disable_tie_weight = disable_tie_weight


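# Usage sketch (illustrative placeholder values): both sub-configs may be given
# as plain dicts; `model_type` is popped and the remainder is rebuilt through
# `AutoConfig.for_model`, e.g.
#     config = OvisConfig(
#         llm_config={"model_type": "qwen2", "hidden_size": 896},
#         visual_tokenizer_config={"model_type": "aimv2_visual_tokenizer"},
#         multimodal_max_length=8192,
#     )

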
# ----------------------------------------------------------------------
# Conversation Formatter
# ----------------------------------------------------------------------
class ConversationFormatter(ABC):
    support_tokenizer_types = None

    def __init__(self, tokenizer):
        tokenizer_type = type(tokenizer).__name__
        assert tokenizer_type in self.support_tokenizer_types, \
            f'Invalid tokenizer type, expected one of `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
        self.tokenizer = tokenizer
        self.image_token = IMAGE_TOKEN
        self.image_token_id = IMAGE_TOKEN_ID
        self.ignore_id = IGNORE_ID

    def _tokenize_with_image_symbol(self, text):
        # tokenize the text on either side of each "<image>" occurrence, then
        # rejoin the chunks with the out-of-vocabulary IMAGE_TOKEN_ID placeholder
        text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
                       text.split(self.image_token)]
        token_ids = []
        num_chunks = len(text_chunks)
        for i, chunk in enumerate(text_chunks):
            token_ids.extend(chunk)
            if i < num_chunks - 1:
                token_ids.append(self.image_token_id)
        return token_ids
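
    # For example, "describe <image> in detail" tokenizes to
    #     ids("describe ") + [IMAGE_TOKEN_ID] + ids(" in detail"),
    # where ids(.) stands for the tokenizer's ordinary input_ids (illustrative).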

    @abstractmethod
    def format(self, conversations: List[Dict], generation_preface=None):
        pass

    @abstractmethod
    def format_query(self, query, generation_preface=""):
        pass


class QwenConversationFormatter(ConversationFormatter):
    support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.from2role = {
            "system": "<|im_start|>system\n",
            "human": "<|im_start|>user\n",
            "gpt": "<|im_start|>assistant\n",
        }
        self.gpt_token_num = None
        self.im_end = "<|im_end|>\n"
        self.default_system_prompt = "You are a helpful assistant."

    def format(self, conversations: List[Dict], generation_preface=None):
        if self.gpt_token_num is None:
            # cache the token length of the assistant role prefix; these prefix
            # positions stay masked with ignore_id in the labels
            self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)

        # note: `conversations` is mutated in place below
        if conversations[0]["from"] != "system":
            conversations.insert(0, {
                "from": "system",
                "value": self.default_system_prompt
            })

        if generation_preface is not None:
            conversations.append({
                "from": "gpt",
                "value": generation_preface
            })

        prompt = ""
        input_ids = []
        labels = []
        num_conversation = len(conversations)
        for i, conversation in enumerate(conversations):
            frm = conversation["from"]
            role = self.from2role[frm]
            message = conversation["value"]
            text = role + message
            if i < num_conversation - 1 or generation_preface is None:
                text += self.im_end
            prompt += text
            token_ids = self._tokenize_with_image_symbol(text)
            input_ids.extend(token_ids)
            label_ids = [self.ignore_id] * len(token_ids)
            if frm == "gpt" and generation_preface is None:
                # learning the `\n` that follows `im_end` is meaningless, so the last `\n` token is also ignored in the labels
                label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
            labels.extend(label_ids)

        assert self._tokenize_with_image_symbol(prompt) == input_ids
        assert len(input_ids) == len(labels)

        return prompt, input_ids, labels
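
    # The assembled prompt follows the ChatML layout, e.g. (illustrative):
    #     <|im_start|>system
    #     You are a helpful assistant.<|im_end|>
    #     <|im_start|>user
    #     <image>
    #     What is shown here?<|im_end|>
    #     <|im_start|>assistant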

    def format_query(self, query, generation_preface=""):
        prompt, input_ids, _ = self.format([{
            "from": "human",
            "value": query
        }], generation_preface=generation_preface)

        return prompt, input_ids
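

# Usage sketch (illustrative; assumes a Qwen2 tokenizer is available):
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
#     formatter = QwenConversationFormatter(tokenizer)
#     prompt, input_ids = formatter.format_query("<image>\nWhat is shown here?")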