Update agent.py
Browse files
agent.py
CHANGED
|
@@ -49,7 +49,7 @@ def _download_file(file_id: str) -> bytes:
|
|
| 49 |
# --------------------------------------------------------------------------- #
|
| 50 |
class GeminiModel:
|
| 51 |
"""
|
| 52 |
-
Thin adapter around google-genai
|
| 53 |
"""
|
| 54 |
|
| 55 |
def __init__(
|
|
@@ -61,15 +61,15 @@ class GeminiModel:
|
|
| 61 |
api_key = os.getenv("GOOGLE_API_KEY")
|
| 62 |
if not api_key:
|
| 63 |
raise EnvironmentError("GOOGLE_API_KEY is not set.")
|
| 64 |
-
# One client per process is enough
|
| 65 |
self.client = genai.Client(api_key=api_key)
|
| 66 |
self.model_name = model_name
|
| 67 |
self.temperature = temperature
|
| 68 |
self.max_tokens = max_tokens
|
| 69 |
|
| 70 |
-
# ----------
|
| 71 |
def call(self, prompt: str, **kwargs) -> str:
|
| 72 |
-
|
|
|
|
| 73 |
model=self.model_name,
|
| 74 |
contents=prompt,
|
| 75 |
generation_config=gtypes.GenerateContentConfig(
|
|
@@ -77,23 +77,16 @@ class GeminiModel:
|
|
| 77 |
max_output_tokens=self.max_tokens,
|
| 78 |
),
|
| 79 |
)
|
| 80 |
-
return
|
| 81 |
|
| 82 |
-
# ---------- smolagents will use this when messages are present ---------- #
|
| 83 |
def call_messages(self, messages, **kwargs) -> str:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# Multimodal path – pass system text first, then structured user parts
|
| 92 |
-
contents = [sys_msg["content"], *user_msg["content"]]
|
| 93 |
-
else:
|
| 94 |
-
# Text prompt path
|
| 95 |
-
contents = f"{sys_msg['content']}\n\n{user_msg['content']}"
|
| 96 |
-
response = self.client.models.generate_content(
|
| 97 |
model=self.model_name,
|
| 98 |
contents=contents,
|
| 99 |
generation_config=gtypes.GenerateContentConfig(
|
|
@@ -101,7 +94,11 @@ class GeminiModel:
|
|
| 101 |
max_output_tokens=self.max_tokens,
|
| 102 |
),
|
| 103 |
)
|
| 104 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
# --------------------------------------------------------------------------- #
|
|
|
|
| 49 |
# --------------------------------------------------------------------------- #
|
| 50 |
class GeminiModel:
|
| 51 |
"""
|
| 52 |
+
Thin adapter around google-genai Client for smolagents.
|
| 53 |
"""
|
| 54 |
|
| 55 |
def __init__(
|
|
|
|
| 61 |
api_key = os.getenv("GOOGLE_API_KEY")
|
| 62 |
if not api_key:
|
| 63 |
raise EnvironmentError("GOOGLE_API_KEY is not set.")
|
|
|
|
| 64 |
self.client = genai.Client(api_key=api_key)
|
| 65 |
self.model_name = model_name
|
| 66 |
self.temperature = temperature
|
| 67 |
self.max_tokens = max_tokens
|
| 68 |
|
| 69 |
+
# ---------- main generation helpers ---------- #
|
| 70 |
def call(self, prompt: str, **kwargs) -> str:
|
| 71 |
+
"""Text-only helper used by __call__."""
|
| 72 |
+
resp = self.client.models.generate_content(
|
| 73 |
model=self.model_name,
|
| 74 |
contents=prompt,
|
| 75 |
generation_config=gtypes.GenerateContentConfig(
|
|
|
|
| 77 |
max_output_tokens=self.max_tokens,
|
| 78 |
),
|
| 79 |
)
|
| 80 |
+
return resp.text.strip()
|
| 81 |
|
|
|
|
| 82 |
def call_messages(self, messages, **kwargs) -> str:
|
| 83 |
+
sys_msg, user_msg = messages
|
| 84 |
+
contents = (
|
| 85 |
+
[sys_msg["content"], *user_msg["content"]]
|
| 86 |
+
if isinstance(user_msg["content"], list)
|
| 87 |
+
else f"{sys_msg['content']}\n\n{user_msg['content']}"
|
| 88 |
+
)
|
| 89 |
+
resp = self.client.models.generate_content(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
model=self.model_name,
|
| 91 |
contents=contents,
|
| 92 |
generation_config=gtypes.GenerateContentConfig(
|
|
|
|
| 94 |
max_output_tokens=self.max_tokens,
|
| 95 |
),
|
| 96 |
)
|
| 97 |
+
return resp.text.strip()
|
| 98 |
+
|
| 99 |
+
# ---------- make the instance itself callable ---------- #
|
| 100 |
+
def __call__(self, prompt: str, **kwargs) -> str: # <-- NEW
|
| 101 |
+
return self.call(prompt, **kwargs)
|
| 102 |
|
| 103 |
|
| 104 |
# --------------------------------------------------------------------------- #
|