dkolarova committed · verified
Commit 62b92f5 · 1 Parent(s): dbd6994

Create models.py

Files changed (1):
  models.py +846 -0
models.py ADDED
@@ -0,0 +1,846 @@
+ import base64
+ import importlib.metadata
+ import json
+ import logging
+ import os
+ import random
+ from copy import deepcopy
+ from dataclasses import asdict, dataclass
+ from enum import Enum
+ from io import BytesIO
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+ from huggingface_hub import InferenceClient
+ from huggingface_hub.utils import is_torch_available
+ from PIL import Image
+
+ from smolagents.tools import Tool
+
+
+ if TYPE_CHECKING:
+     from transformers import StoppingCriteriaList
+
+ logger = logging.getLogger(__name__)
+
+ DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+     "type": "regex",
+     "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_code>',
+ }
+
+ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+     "type": "regex",
+     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
+ }
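+
+ # Illustrative only: the code-agent grammar above matches completions shaped like
+ # the following made-up sample:
+ #
+ #   Thought: I need to add two numbers.
+ #   Code:
+ #   ```py
+ #   print(1 + 2)
+ #   ```<end_code>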
+
+
+ def _is_package_available(package_name: str) -> bool:
+     try:
+         importlib.metadata.version(package_name)
+         return True
+     except importlib.metadata.PackageNotFoundError:
+         return False
+
+
+ def encode_image_base64(image):
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+ def make_image_url(base64_image):
+     return f"data:image/png;base64,{base64_image}"
+
+
+ def get_dict_from_nested_dataclasses(obj, ignore_key=None):
+     def convert(obj):
+         if hasattr(obj, "__dataclass_fields__"):
+             return {k: convert(v) for k, v in asdict(obj).items() if k != ignore_key}
+         return obj
+
+     return convert(obj)
+
+
+ @dataclass
+ class ChatMessageToolCallDefinition:
+     arguments: Any
+     name: str
+     description: Optional[str] = None
+
+     @classmethod
+     def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
+         return cls(
+             arguments=tool_call_definition.arguments,
+             name=tool_call_definition.name,
+             description=tool_call_definition.description,
+         )
+
+
+ @dataclass
+ class ChatMessageToolCall:
+     function: ChatMessageToolCallDefinition
+     id: str
+     type: str
+
+     @classmethod
+     def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
+         return cls(
+             function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
+             id=tool_call.id,
+             type=tool_call.type,
+         )
+
+
+ @dataclass
+ class ChatMessage:
+     role: str
+     content: Optional[str] = None
+     tool_calls: Optional[List[ChatMessageToolCall]] = None
+     raw: Optional[Any] = None  # Stores the raw output from the API
+
+     def model_dump_json(self):
+         return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw"))
+
+     @classmethod
+     def from_hf_api(cls, message, raw) -> "ChatMessage":
+         tool_calls = None
+         if getattr(message, "tool_calls", None) is not None:
+             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
+         return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "ChatMessage":
+         if data.get("tool_calls"):
+             tool_calls = [
+                 ChatMessageToolCall(
+                     function=ChatMessageToolCallDefinition(**tc["function"]), id=tc["id"], type=tc["type"]
+                 )
+                 for tc in data["tool_calls"]
+             ]
+             data["tool_calls"] = tool_calls
+         return cls(**data)
+
+     def dict(self):
+         # Note: despite the name, this returns a JSON string, not a dict.
+         return json.dumps(get_dict_from_nested_dataclasses(self))
+
+
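+ # Illustrative sketch (not executed here): serializing a ChatMessage and
+ # rebuilding it from the parsed dict.
+ #
+ #   msg = ChatMessage(role="assistant", content="Hello!")
+ #   payload = json.loads(msg.model_dump_json())
+ #   assert ChatMessage.from_dict(payload).content == "Hello!"
+
+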
+ def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
+     if isinstance(arguments, dict):
+         return arguments
+     else:
+         try:
+             return json.loads(arguments)
+         except Exception:
+             return arguments
+
+
+ def parse_tool_args_if_needed(message: ChatMessage) -> ChatMessage:
+     for tool_call in message.tool_calls or []:  # Guard against a missing tool_calls list
+         tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
+     return message
+
+
+ class MessageRole(str, Enum):
+     USER = "user"
+     ASSISTANT = "assistant"
+     SYSTEM = "system"
+     TOOL_CALL = "tool-call"
+     TOOL_RESPONSE = "tool-response"
+
+     @classmethod
+     def roles(cls):
+         return [r.value for r in cls]
+
+
+ tool_role_conversions = {
+     MessageRole.TOOL_CALL: MessageRole.ASSISTANT,
+     MessageRole.TOOL_RESPONSE: MessageRole.USER,
+ }
+
+
+ def get_tool_json_schema(tool: Tool) -> Dict:
+     properties = deepcopy(tool.inputs)
+     required = []
+     for key, value in properties.items():
+         if value["type"] == "any":
+             value["type"] = "string"
+         if not ("nullable" in value and value["nullable"]):
+             required.append(key)
+     return {
+         "type": "function",
+         "function": {
+             "name": tool.name,
+             "description": tool.description,
+             "parameters": {
+                 "type": "object",
+                 "properties": properties,
+                 "required": required,
+             },
+         },
+     }
+
+
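+ # Illustrative only: for a hypothetical Tool named "echo" with a single required
+ # string input "text", get_tool_json_schema would return roughly:
+ #
+ #   {"type": "function",
+ #    "function": {"name": "echo", "description": "...",
+ #                 "parameters": {"type": "object",
+ #                                "properties": {"text": {"type": "string", ...}},
+ #                                "required": ["text"]}}}
+
+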
+ def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
+     for stop_seq in stop_sequences:
+         if content.endswith(stop_seq):
+             content = content[: -len(stop_seq)]
+     return content
+
+
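+ # For example (illustrative), remove_stop_sequences("final answer<end_code>", ["<end_code>"])
+ # returns "final answer".
+
+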
+ def get_clean_message_list(
+     message_list: List[Dict[str, str]],
+     role_conversions: Dict[MessageRole, MessageRole] = {},
+     convert_images_to_image_urls: bool = False,
+     flatten_messages_as_text: bool = False,
+ ) -> List[Dict[str, str]]:
+     """
+     Creates a message list compatible with the chat template format expected by Transformers LLMs:
+     subsequent messages with the same role are concatenated into a single message.
+
+     Args:
+         message_list (`list[dict[str, str]]`): List of chat messages.
+         role_conversions (`dict[MessageRole, MessageRole]`, *optional*): Mapping to convert roles.
+         convert_images_to_image_urls (`bool`, default `False`): Whether to convert images to image URLs.
+         flatten_messages_as_text (`bool`, default `False`): Whether to flatten messages as text.
+     """
+     output_message_list = []
+     message_list = deepcopy(message_list)  # Avoid modifying the original list
+     for message in message_list:
+         role = message["role"]
+         if role not in MessageRole.roles():
+             raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
+
+         if role in role_conversions:
+             message["role"] = role_conversions[role]
+         # Encode images if needed
+         if isinstance(message["content"], list):
+             for element in message["content"]:
+                 if element["type"] == "image":
+                     assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
+                     if convert_images_to_image_urls:
+                         element.update(
+                             {
+                                 "type": "image_url",
+                                 "image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
+                             }
+                         )
+                     else:
+                         element["image"] = encode_image_base64(element["image"])
+
+         # Merge consecutive messages that share a role
+         if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]:
+             assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"])
+             if flatten_messages_as_text:
+                 output_message_list[-1]["content"] += message["content"][0]["text"]
+             else:
+                 output_message_list[-1]["content"] += message["content"]
+         else:
+             if flatten_messages_as_text:
+                 content = message["content"][0]["text"]
+             else:
+                 content = message["content"]
+             output_message_list.append({"role": message["role"], "content": content})
+     return output_message_list
+
+
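+ # Illustrative sketch: two consecutive user messages are merged into one.
+ #
+ #   msgs = [
+ #       {"role": "user", "content": [{"type": "text", "text": "Hello "}]},
+ #       {"role": "user", "content": [{"type": "text", "text": "world"}]},
+ #   ]
+ #   get_clean_message_list(msgs, flatten_messages_as_text=True)
+ #   # -> [{"role": "user", "content": "Hello world"}]
+
+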
+ class Model:
+     def __init__(self, **kwargs):
+         self.last_input_token_count = None
+         self.last_output_token_count = None
+         self.kwargs = kwargs
+
+     def _prepare_completion_kwargs(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         custom_role_conversions: Optional[Dict[str, str]] = None,
+         convert_images_to_image_urls: bool = False,
+         flatten_messages_as_text: bool = False,
+         **kwargs,
+     ) -> Dict:
+         """
+         Prepare parameters required for model invocation, handling parameter priorities.
+
+         Parameter priority from high to low:
+         1. Explicitly passed kwargs
+         2. Specific parameters (stop_sequences, grammar, etc.)
+         3. Default values in self.kwargs
+         """
+         # Clean and standardize the message list
+         messages = get_clean_message_list(
+             messages,
+             role_conversions=custom_role_conversions or tool_role_conversions,
+             convert_images_to_image_urls=convert_images_to_image_urls,
+             flatten_messages_as_text=flatten_messages_as_text,
+         )
+
+         # Use self.kwargs as the base configuration
+         completion_kwargs = {
+             **self.kwargs,
+             "messages": messages,
+         }
+
+         # Handle specific parameters
+         if stop_sequences is not None:
+             completion_kwargs["stop"] = stop_sequences
+         if grammar is not None:
+             completion_kwargs["grammar"] = grammar
+
+         # Handle tools parameter
+         if tools_to_call_from:
+             completion_kwargs.update(
+                 {
+                     "tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
+                     "tool_choice": "required",
+                 }
+             )
+
+         # Finally, use the passed-in kwargs to override all settings
+         completion_kwargs.update(kwargs)
+
+         return completion_kwargs
+
+     def get_token_counts(self) -> Dict[str, int]:
+         return {
+             "input_token_count": self.last_input_token_count,
+             "output_token_count": self.last_output_token_count,
+         }
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+         """Process the input messages and return the model's response.
+
+         Parameters:
+             messages (`List[Dict[str, str]]`):
+                 A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
+             stop_sequences (`List[str]`, *optional*):
+                 A list of strings that will stop the generation if encountered in the model's output.
+             grammar (`str`, *optional*):
+                 The grammar or formatting structure to use in the model's response.
+             tools_to_call_from (`List[Tool]`, *optional*):
+                 A list of tools that the model can use to generate responses.
+             **kwargs:
+                 Additional keyword arguments to be passed to the underlying model.
+
+         Returns:
+             `ChatMessage`: A chat message object containing the model's response.
+         """
+         pass  # To be implemented in child classes!
+
+
+ class HfApiModel(Model):
+     """A class to interact with Hugging Face's Inference API for language model interaction.
+
+     This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in either serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+
+     Parameters:
+         model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
+             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+         provider (`str`, *optional*, defaults to `"hf-inference"`, the HF Inference API):
+             Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
+         token (`str`, *optional*):
+             Token used by the Hugging Face API for authentication. This token needs to be authorized for 'Make calls to the serverless Inference API'.
+             If the model is gated (like Llama-3 models), the token also needs 'Read access to contents of all public gated repos you can access'.
+             If not provided, the class will try to use the environment variable 'HF_TOKEN', else use the token stored in the Hugging Face CLI configuration.
+         timeout (`int`, *optional*, defaults to 120):
+             Timeout for the API request, in seconds.
+         custom_role_conversions (`dict[str, str]`, *optional*):
+             Custom role conversion mapping to convert message roles into others.
+             Useful for specific models that do not support specific message roles like "system".
+         **kwargs:
+             Additional keyword arguments to pass to the Hugging Face API.
+
+     Raises:
+         ValueError:
+             If the model name is not provided.
+
+     Example:
+     ```python
+     >>> engine = HfApiModel(
+     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+     ...     token="your_hf_token_here",
+     ...     max_tokens=5000,
+     ... )
+     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+     >>> response = engine(messages, stop_sequences=["END"])
+     >>> print(response)
+     "Quantum mechanics is the branch of physics that studies..."
+     ```
+     """
+
+     def __init__(
+         self,
+         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+         provider: Optional[str] = None,
+         token: Optional[str] = None,
+         timeout: Optional[int] = 120,
+         custom_role_conversions: Optional[Dict[str, str]] = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.model_id = model_id
+         self.provider = provider
+         if token is None:
+             token = os.getenv("HF_TOKEN")
+         self.client = InferenceClient(self.model_id, provider=provider, token=token, timeout=timeout)
+         self.custom_role_conversions = custom_role_conversions
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+         completion_kwargs = self._prepare_completion_kwargs(
+             messages=messages,
+             stop_sequences=stop_sequences,
+             grammar=grammar,
+             tools_to_call_from=tools_to_call_from,
+             convert_images_to_image_urls=True,
+             custom_role_conversions=self.custom_role_conversions,
+             **kwargs,
+         )
+         response = self.client.chat_completion(**completion_kwargs)
+
+         self.last_input_token_count = response.usage.prompt_tokens
+         self.last_output_token_count = response.usage.completion_tokens
+         message = ChatMessage.from_hf_api(response.choices[0].message, raw=response)
+         if tools_to_call_from is not None:
+             return parse_tool_args_if_needed(message)
+         return message
+
+
+ class TransformersModel(Model):
+     """A class that uses Hugging Face's Transformers library for language model interaction.
+
+     This model allows you to load and use Hugging Face's models locally using the Transformers library. It supports features like stop sequences and grammar customization.
+
+     > [!TIP]
+     > You must have `transformers` and `torch` installed on your machine. Please run `pip install smolagents[transformers]` if it's not the case.
+
+     Parameters:
+         model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
+             The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
+         device_map (`str`, *optional*):
+             The device_map to initialize your model with.
+         torch_dtype (`str`, *optional*):
+             The torch_dtype to initialize your model with.
+         trust_remote_code (bool, default `False`):
+             Some models on the Hub require running remote code: for such models, you would have to set this flag to True.
+         **kwargs:
+             Additional keyword arguments to pass to `model.generate()`, for instance `max_new_tokens`.
+
+     Raises:
+         ValueError:
+             If the model name is not provided.
+
+     Example:
+     ```python
+     >>> engine = TransformersModel(
+     ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+     ...     device_map="cuda",
+     ...     max_new_tokens=5000,
+     ... )
+     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+     >>> response = engine(messages, stop_sequences=["END"])
+     >>> print(response)
+     "Quantum mechanics is the branch of physics that studies..."
+     ```
+     """
+
+     def __init__(
+         self,
+         model_id: Optional[str] = None,
+         device_map: Optional[str] = None,
+         torch_dtype: Optional[str] = None,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if not is_torch_available() or not _is_package_available("transformers"):
+             raise ModuleNotFoundError(
+                 "Please install 'transformers' extra to use 'TransformersModel': `pip install 'smolagents[transformers]'`"
+             )
+         import torch
+         from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
+
+         default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+         if model_id is None:
+             model_id = default_model_id
+             logger.warning(f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'")
+         self.model_id = model_id
+         self.kwargs = kwargs
+         if device_map is None:
+             device_map = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {device_map}")
+         self._is_vlm = False
+         try:
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 device_map=device_map,
+                 torch_dtype=torch_dtype,
+                 trust_remote_code=trust_remote_code,
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+         except ValueError as e:
+             if "Unrecognized configuration class" in str(e):
+                 # Fall back to loading as a vision-language model
+                 self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=device_map)
+                 self.processor = AutoProcessor.from_pretrained(model_id)
+                 self._is_vlm = True
+             else:
+                 raise e
+         except Exception as e:
+             logger.warning(
+                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
+             )
+             self.model_id = default_model_id
+             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 default_model_id, device_map=device_map, torch_dtype=torch_dtype
+             )
+
+     def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
+         from transformers import StoppingCriteria, StoppingCriteriaList
+
+         class StopOnStrings(StoppingCriteria):
+             def __init__(self, stop_strings: List[str], tokenizer):
+                 self.stop_strings = stop_strings
+                 self.tokenizer = tokenizer
+                 self.stream = ""
+
+             def reset(self):
+                 self.stream = ""
+
+             def __call__(self, input_ids, scores, **kwargs):
+                 generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
+                 self.stream += generated
+                 if any(self.stream.endswith(stop_string) for stop_string in self.stop_strings):
+                     return True
+                 return False
+
+         return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         images: Optional[List[Image.Image]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+         completion_kwargs = self._prepare_completion_kwargs(
+             messages=messages,
+             stop_sequences=stop_sequences,
+             grammar=grammar,
+             flatten_messages_as_text=(not self._is_vlm),
+             **kwargs,
+         )
+
+         messages = completion_kwargs.pop("messages")
+         stop_sequences = completion_kwargs.pop("stop", None)
+
+         max_new_tokens = (
+             kwargs.get("max_new_tokens")
+             or kwargs.get("max_tokens")
+             or self.kwargs.get("max_new_tokens")
+             or self.kwargs.get("max_tokens")
+         )
+
+         if max_new_tokens:
+             completion_kwargs["max_new_tokens"] = max_new_tokens
+
+         if hasattr(self, "processor"):
+             images = [Image.open(image) for image in images] if images else None
+             prompt_tensor = self.processor.apply_chat_template(
+                 messages,
+                 tools=[get_tool_json_schema(tool) for tool in tools_to_call_from] if tools_to_call_from else None,
+                 return_tensors="pt",
+                 tokenize=True,
+                 return_dict=True,
+                 images=images,
+                 add_generation_prompt=True if tools_to_call_from else False,
+             )
+         else:
+             prompt_tensor = self.tokenizer.apply_chat_template(
+                 messages,
+                 tools=[get_tool_json_schema(tool) for tool in tools_to_call_from] if tools_to_call_from else None,
+                 return_tensors="pt",
+                 return_dict=True,
+                 add_generation_prompt=True if tools_to_call_from else False,
+             )
+
+         prompt_tensor = prompt_tensor.to(self.model.device)
+         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
+
+         if stop_sequences:
+             stopping_criteria = self.make_stopping_criteria(
+                 stop_sequences, tokenizer=self.processor if hasattr(self, "processor") else self.tokenizer
+             )
+         else:
+             stopping_criteria = None
+
+         out = self.model.generate(
+             **prompt_tensor,
+             stopping_criteria=stopping_criteria,
+             **completion_kwargs,
+         )
+         generated_tokens = out[0, count_prompt_tokens:]
+         if hasattr(self, "processor"):
+             output = self.processor.decode(generated_tokens, skip_special_tokens=True)
+         else:
+             output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+         self.last_input_token_count = count_prompt_tokens
+         self.last_output_token_count = len(generated_tokens)
+
+         if stop_sequences is not None:
+             output = remove_stop_sequences(output, stop_sequences)
+
+         if tools_to_call_from is None:
+             return ChatMessage(
+                 role="assistant",
+                 content=output,
+                 raw={"out": out, "completion_kwargs": completion_kwargs},
+             )
+         else:
+             if "Action:" in output:
+                 output = output.split("Action:", 1)[1].strip()
+             try:
+                 start_index = output.index("{")
+                 end_index = output.rindex("}")
+                 output = output[start_index : end_index + 1]
+             except Exception as e:
+                 raise Exception("No json blob found in output!") from e
+
+             try:
+                 parsed_output = json.loads(output)
+             except json.JSONDecodeError as e:
+                 raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}") from e
+             tool_name = parsed_output.get("name")
+             tool_arguments = parsed_output.get("arguments")
+             return ChatMessage(
+                 role="assistant",
+                 content="",
+                 tool_calls=[
+                     ChatMessageToolCall(
+                         id="".join(random.choices("0123456789", k=5)),
+                         type="function",
+                         function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+                     )
+                 ],
+                 raw={"out": out, "completion_kwargs": completion_kwargs},
+             )
+
+
+ class LiteLLMModel(Model):
+     """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+     Parameters:
+         model_id (`str`):
+             The model identifier to use with LiteLLM (e.g. "gpt-3.5-turbo").
+         api_base (`str`, *optional*):
+             The base URL of the OpenAI-compatible API server.
+         api_key (`str`, *optional*):
+             The API key to use for authentication.
+         custom_role_conversions (`dict[str, str]`, *optional*):
+             Custom role conversion mapping to convert message roles into others.
+             Useful for specific models that do not support specific message roles like "system".
+         **kwargs:
+             Additional keyword arguments to pass to `litellm.completion()`.
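+
+     Example (illustrative; the API key is a placeholder):
+     ```python
+     >>> engine = LiteLLMModel(
+     ...     model_id="anthropic/claude-3-5-sonnet-20240620",
+     ...     api_key="your_api_key_here",
+     ... )
+     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+     >>> response = engine(messages, stop_sequences=["END"])
+     ```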
+     """
+
+     def __init__(
+         self,
+         model_id: str = "anthropic/claude-3-5-sonnet-20240620",
+         api_base: Optional[str] = None,
+         api_key: Optional[str] = None,
+         custom_role_conversions: Optional[Dict[str, str]] = None,
+         **kwargs,
+     ):
+         try:
+             import litellm
+         except ModuleNotFoundError:
+             raise ModuleNotFoundError(
+                 "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
+             )
+
+         super().__init__(**kwargs)
+         self.model_id = model_id
+         # IMPORTANT: set this to True to add function schemas to the prompt for non-OpenAI LLMs
+         litellm.add_function_to_prompt = True
+         self.api_base = api_base
+         self.api_key = api_key
+         self.custom_role_conversions = custom_role_conversions
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+         import litellm
+
+         completion_kwargs = self._prepare_completion_kwargs(
+             messages=messages,
+             stop_sequences=stop_sequences,
+             grammar=grammar,
+             tools_to_call_from=tools_to_call_from,
+             model=self.model_id,
+             api_base=self.api_base,
+             api_key=self.api_key,
+             convert_images_to_image_urls=True,
+             flatten_messages_as_text=self.model_id.startswith("ollama"),
+             custom_role_conversions=self.custom_role_conversions,
+             **kwargs,
+         )
+
+         response = litellm.completion(**completion_kwargs)
+
+         self.last_input_token_count = response.usage.prompt_tokens
+         self.last_output_token_count = response.usage.completion_tokens
+         message = ChatMessage.from_dict(
+             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+         )
+         message.raw = response
+
+         if tools_to_call_from is not None:
+             return parse_tool_args_if_needed(message)
+         return message
+
+
+ class OpenAIServerModel(Model):
+     """This model connects to an OpenAI-compatible API server.
+
+     Parameters:
+         model_id (`str`):
+             The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+         api_base (`str`, *optional*):
+             The base URL of the OpenAI-compatible API server.
+         api_key (`str`, *optional*):
+             The API key to use for authentication.
+         organization (`str`, *optional*):
+             The organization to use for the API request.
+         project (`str`, *optional*):
+             The project to use for the API request.
+         custom_role_conversions (`dict[str, str]`, *optional*):
+             Custom role conversion mapping to convert message roles into others.
+             Useful for specific models that do not support specific message roles like "system".
+         **kwargs:
+             Additional keyword arguments to pass to the OpenAI API.
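+
+     Example (illustrative; the endpoint and key are placeholders):
+     ```python
+     >>> engine = OpenAIServerModel(
+     ...     model_id="gpt-4o-mini",
+     ...     api_base="https://api.openai.com/v1",
+     ...     api_key="your_api_key_here",
+     ... )
+     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+     >>> response = engine(messages, stop_sequences=["END"])
+     ```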
+     """
+
+     def __init__(
+         self,
+         model_id: str,
+         api_base: Optional[str] = None,
+         api_key: Optional[str] = None,
+         organization: Optional[str] = None,
+         project: Optional[str] = None,
+         custom_role_conversions: Optional[Dict[str, str]] = None,
+         **kwargs,
+     ):
+         try:
+             import openai
+         except ModuleNotFoundError:
+             raise ModuleNotFoundError(
+                 "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
+             ) from None
+
+         super().__init__(**kwargs)
+         self.model_id = model_id
+         self.client = openai.OpenAI(
+             base_url=api_base,
+             api_key=api_key,
+             organization=organization,
+             project=project,
+         )
+         self.custom_role_conversions = custom_role_conversions
+
+     def __call__(
+         self,
+         messages: List[Dict[str, str]],
+         stop_sequences: Optional[List[str]] = None,
+         grammar: Optional[str] = None,
+         tools_to_call_from: Optional[List[Tool]] = None,
+         **kwargs,
+     ) -> ChatMessage:
+         completion_kwargs = self._prepare_completion_kwargs(
+             messages=messages,
+             stop_sequences=stop_sequences,
+             grammar=grammar,
+             tools_to_call_from=tools_to_call_from,
+             model=self.model_id,
+             custom_role_conversions=self.custom_role_conversions,
+             convert_images_to_image_urls=True,
+             **kwargs,
+         )
+         response = self.client.chat.completions.create(**completion_kwargs)
+         self.last_input_token_count = response.usage.prompt_tokens
+         self.last_output_token_count = response.usage.completion_tokens
+
+         message = ChatMessage.from_dict(
+             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+         )
+         message.raw = response
+         if tools_to_call_from is not None:
+             return parse_tool_args_if_needed(message)
+         return message
+
+
+ class AzureOpenAIServerModel(OpenAIServerModel):
+     """This model connects to an Azure OpenAI deployment.
+
+     Parameters:
+         model_id (`str`):
+             The model deployment name to use when connecting (e.g. "gpt-4o-mini").
+         azure_endpoint (`str`, *optional*):
+             The Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`. If not provided, it will be inferred from the `AZURE_OPENAI_ENDPOINT` environment variable.
+         api_key (`str`, *optional*):
+             The API key to use for authentication. If not provided, it will be inferred from the `AZURE_OPENAI_API_KEY` environment variable.
+         api_version (`str`, *optional*):
+             The API version to use. If not provided, it will be inferred from the `OPENAI_API_VERSION` environment variable.
+         custom_role_conversions (`dict[str, str]`, *optional*):
+             Custom role conversion mapping to convert message roles into others.
+             Useful for specific models that do not support specific message roles like "system".
+         **kwargs:
+             Additional keyword arguments to pass to the Azure OpenAI API.
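+
+     Example (illustrative; endpoint, key, and API version are placeholders):
+     ```python
+     >>> engine = AzureOpenAIServerModel(
+     ...     model_id="gpt-4o-mini",
+     ...     azure_endpoint="https://example-resource.azure.openai.com/",
+     ...     api_key="your_api_key_here",
+     ...     api_version="2024-06-01",
+     ... )
+     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+     >>> response = engine(messages)
+     ```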
+     """
+
+     def __init__(
+         self,
+         model_id: str,
+         azure_endpoint: Optional[str] = None,
+         api_key: Optional[str] = None,
+         api_version: Optional[str] = None,
+         custom_role_conversions: Optional[Dict[str, str]] = None,
+         **kwargs,
+     ):
+         # Read the API key manually, to avoid super().__init__() trying to use the wrong api_key (OPENAI_API_KEY)
+         if api_key is None:
+             api_key = os.environ.get("AZURE_OPENAI_API_KEY")
+
+         super().__init__(model_id=model_id, api_key=api_key, custom_role_conversions=custom_role_conversions, **kwargs)
+         # If we've reached this point, the openai package is available (checked in the base class), so import it
+         import openai
+
+         self.client = openai.AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint)
+
+
+ __all__ = [
+     "MessageRole",
+     "tool_role_conversions",
+     "get_clean_message_list",
+     "Model",
+     "TransformersModel",
+     "HfApiModel",
+     "LiteLLMModel",
+     "OpenAIServerModel",
+     "AzureOpenAIServerModel",
+     "ChatMessage",
+ ]