import os
import json
import base64
import time
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import aiohttp
from aiohttp import ClientSession, ClientTimeout
from ...typing import AsyncResult, Messages, MediaListType
from ...errors import MissingAuthError
from ...image.copy_images import save_response_media
from ...image import to_bytes, is_data_an_media
from ...providers.response import Usage, ImageResponse, ToolCalls, Reasoning
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
from ..helper import get_connector, get_system_prompt, format_media_prompt
from ... import debug
def get_oauth_creds_path():
return Path.home() / ".gemini" / "oauth_creds.json"
class AuthManager(AuthFileMixin):
"""
Handles OAuth2 authentication and Google Code Assist API communication.
Manages token caching, refresh, and API calls.
Requires environment dict-like object with keys:
- GCP_SERVICE_ACCOUNT: JSON string with OAuth2 credentials, containing:
access_token, expiry_date (ms timestamp), refresh_token
- Optionally supports cache storage via a KV storage interface implementing:
get(key) -> value or None,
put(key, value, expiration_seconds),
delete(key)
"""
parent = "GeminiCLI"
OAUTH_REFRESH_URL = "https://oauth2.googleapis.com/token"
OAUTH_CLIENT_ID = "681255809395" + "-oo8ft2oprdrnp9e3aqf6av3hmdib135j" + ".apps.googleusercontent.com"
OAUTH_CLIENT_SECRET = "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"
TOKEN_BUFFER_TIME = 5 * 60 # seconds, 5 minutes
KV_TOKEN_KEY = "oauth_token_cache"
def __init__(self, env: Dict[str, Any]):
self.env = env
self._access_token: Optional[str] = None
self._expiry: Optional[float] = None # Unix timestamp in seconds
self._token_cache = {} # Example in-memory cache; replace with KV store for production
async def initialize_auth(self) -> None:
"""
Initialize authentication by using cached token, or refreshing if needed.
Raises RuntimeError if no valid token can be obtained.
"""
# Try cached token from KV store or in-memory cache
cached = await self._get_cached_token()
now = time.time()
if cached:
expires_at = cached["expiry_date"] / 1000 # ms to seconds
if expires_at - now > self.TOKEN_BUFFER_TIME:
self._access_token = cached["access_token"]
self._expiry = expires_at
return # Use cached token if valid
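        # No usable cached token: fall back to on-disk credentials.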
path = AuthManager.get_cache_file()
if not path.exists():
path = get_oauth_creds_path()
if path.exists():
try:
with path.open("r") as f:
creds = json.load(f)
except Exception as e:
raise RuntimeError(f"Failed to read OAuth credentials from {path}: {e}")
else:
# Parse credentials from environment
if "GCP_SERVICE_ACCOUNT" not in self.env:
raise RuntimeError("GCP_SERVICE_ACCOUNT environment variable not set.")
creds = json.loads(self.env["GCP_SERVICE_ACCOUNT"])
refresh_token = creds.get("refresh_token")
access_token = creds.get("access_token")
expiry_date = creds.get("expiry_date") # milliseconds since epoch
# Use original access token if still valid
if access_token and expiry_date:
expires_at = expiry_date / 1000
if expires_at - now > self.TOKEN_BUFFER_TIME:
self._access_token = access_token
self._expiry = expires_at
await self._cache_token(access_token, expiry_date)
return
# Otherwise, refresh token
if not refresh_token:
raise RuntimeError("No refresh token found in GCP_SERVICE_ACCOUNT.")
await self._refresh_and_cache_token(refresh_token)
async def _refresh_and_cache_token(self, refresh_token: str) -> None:
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"client_id": self.OAUTH_CLIENT_ID,
"client_secret": self.OAUTH_CLIENT_SECRET,
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
async with aiohttp.ClientSession() as session:
async with session.post(self.OAUTH_REFRESH_URL, data=data, headers=headers) as resp:
if resp.status != 200:
text = await resp.text()
raise RuntimeError(f"Token refresh failed: {text}")
resp_data = await resp.json()
access_token = resp_data.get("access_token")
expires_in = resp_data.get("expires_in", 3600) # seconds
if not access_token:
raise RuntimeError("No access_token in refresh response.")
self._access_token = access_token
self._expiry = time.time() + expires_in
expiry_date_ms = int(self._expiry * 1000) # milliseconds
await self._cache_token(access_token, expiry_date_ms)
async def _cache_token(self, access_token: str, expiry_date: int) -> None:
# Cache token in KV store or fallback to memory cache
token_data = {
"access_token": access_token,
"expiry_date": expiry_date,
"cached_at": int(time.time() * 1000), # ms
}
self._token_cache[self.KV_TOKEN_KEY] = token_data
async def _get_cached_token(self) -> Optional[Dict[str, Any]]:
# Return in-memory cached token if present and still valid
cached = self._token_cache.get(self.KV_TOKEN_KEY)
if cached:
expires_at = cached["expiry_date"] / 1000
if expires_at - time.time() > self.TOKEN_BUFFER_TIME:
return cached
return None
    async def clear_token_cache(self) -> None:
        # Drop both the live token and the cached entry so the next
        # initialize_auth() call is forced to refresh.
        self._token_cache.pop(self.KV_TOKEN_KEY, None)
        self._access_token = None
        self._expiry = None
def get_access_token(self) -> Optional[str]:
# Return current valid access token or None
if (
self._access_token is not None
and self._expiry is not None
and self._expiry - time.time() > self.TOKEN_BUFFER_TIME
):
return self._access_token
return None
    async def call_endpoint(self, method: str, body: Dict[str, Any], is_retry: bool = False) -> Any:
"""
Call Google Code Assist API endpoint with JSON body.
Automatically retries once on 401 Unauthorized by refreshing auth.
"""
if not self.get_access_token():
await self.initialize_auth()
url = f"https://cloudcode-pa.googleapis.com/v1internal:{method}"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.get_access_token()}",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=body) as resp:
if resp.status == 401 and not is_retry:
# Token likely expired, clear and retry once
await self.clear_token_cache()
await self.initialize_auth()
return await self.call_endpoint(method, body, is_retry=True)
elif not resp.ok:
text = await resp.text()
raise RuntimeError(f"API call failed with status {resp.status}: {text}")
return await resp.json()
class GeminiCLIProvider:
url = "https://cloud.google.com/code-assist"
api_base = "https://cloudcode-pa.googleapis.com/v1internal"
# Required for authentication and token management; Expects a compatible AuthManager instance
auth_manager: AuthManager
env: dict
def __init__(self, env: dict, auth_manager: AuthManager):
self.env = env
self.auth_manager = auth_manager
# Cache for discovered project ID
self._project_id: Optional[str] = None
async def discover_project_id(self) -> str:
if self.env.get("GEMINI_PROJECT_ID"):
return self.env["GEMINI_PROJECT_ID"]
if self._project_id:
return self._project_id
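        # Ask the Code Assist backend which cloudaicompanionProject this account
        # is mapped to; "default-project" is only a placeholder for the discovery call.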
try:
load_response = await self.auth_manager.call_endpoint(
"loadCodeAssist",
{
"cloudaicompanionProject": "default-project",
"metadata": {"duetProject": "default-project"},
},
)
project = load_response.get("cloudaicompanionProject")
if project:
self._project_id = project
return project
raise RuntimeError(
"Project ID discovery failed - set GEMINI_PROJECT_ID in environment."
)
except Exception as e:
debug.error(f"Failed to discover project ID: {e}")
raise RuntimeError(
"Could not discover project ID. Ensure authentication or set GEMINI_PROJECT_ID."
)
@staticmethod
    def _messages_to_gemini_format(messages: list, media: MediaListType) -> List[Dict[str, Any]]:
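        """
        Convert OpenAI-style chat messages (plus optional media) into the Gemini
        `contents` structure, e.g.:

            [{"role": "user", "parts": [{"text": "Hello"}]},
             {"role": "model", "parts": [{"text": "Hi there."}]}]
        """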
format_messages = []
for msg in messages:
# Convert a ChatMessage dict to GeminiFormattedMessage dict
role = "model" if msg["role"] == "assistant" else "user"
# Handle tool role (OpenAI style)
if msg["role"] == "tool":
parts = [
{
"functionResponse": {
"name": msg.get("tool_call_id", "unknown_function"),
"response": {
"result": (
msg["content"]
if isinstance(msg["content"], str)
else json.dumps(msg["content"])
)
},
}
}
                ]
# Handle assistant messages with tool calls
elif msg["role"] == "assistant" and msg.get("tool_calls"):
parts = []
if isinstance(msg["content"], str) and msg["content"].strip():
parts.append({"text": msg["content"]})
for tool_call in msg["tool_calls"]:
if tool_call.get("type") == "function":
parts.append(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(tool_call["function"]["arguments"]),
}
}
)
# Handle string content
elif isinstance(msg["content"], str):
parts = [{"text": msg["content"]}]
# Handle array content (possibly multimodal)
elif isinstance(msg["content"], list):
for content in msg["content"]:
ctype = content.get("type")
if ctype == "text":
parts.append({"text": content["text"]})
elif ctype == "image_url":
image_url = content.get("image_url", {}).get("url")
if not image_url:
continue
if image_url.startswith("data:"):
# Inline base64 data image
prefix, b64data = image_url.split(",", 1)
mime_type = prefix.split(":")[1].split(";")[0]
parts.append({"inlineData": {"mimeType": mime_type, "data": b64data}})
else:
parts.append(
{
"fileData": {
"mimeType": "image/jpeg", # Could improve by validation
"fileUri": image_url,
}
}
)
else:
parts = [{"text": str(msg["content"])}]
format_messages.append({"role": role, "parts": parts})
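        # Attach extra media items to the last formatted message, creating an
        # empty user turn if the conversation is empty.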
if media:
if not format_messages:
format_messages.append({"role": "user", "parts": []})
for media_data, filename in media:
if isinstance(media_data, str):
if not filename:
filename = media_data
extension = filename.split(".")[-1].replace("jpg", "jpeg")
format_messages[-1]["parts"].append(
{
"fileData": {
"mimeType": f"image/{extension}",
"fileUri": image_url,
}
}
)
else:
media_data = to_bytes(media_data)
format_messages[-1]["parts"].append({
"inlineData": {
"mimeType": is_data_an_media(media_data, filename),
"data": base64.b64encode(media_data).decode()
}
})
return format_messages
async def stream_content(
self,
model: str,
messages: Messages,
*,
proxy: Optional[str] = None,
thinking_budget: Optional[int] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[Union[str, List[str]]] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
seed: Optional[int] = None,
response_format: Optional[Dict[str, Any]] = None,
**kwargs
) -> AsyncGenerator:
await self.auth_manager.initialize_auth()
project_id = await self.discover_project_id()
# Convert messages to Gemini format
        contents = self._messages_to_gemini_format(
            [m for m in messages if m["role"] not in ["developer", "system"]],
            media=kwargs.get("media", None),
        )
system_prompt = get_system_prompt(messages)
requestData = {}
if system_prompt:
requestData["system_instruction"] = {"parts": {"text": system_prompt}}
# Compose request body
req_body = {
"model": model,
"project": project_id,
"request": {
"contents": contents,
"generationConfig": {
"maxOutputTokens": max_tokens,
"temperature": temperature,
"topP": top_p,
"stop": stop,
"presencePenalty": presence_penalty,
"frequencyPenalty": frequency_penalty,
"seed": seed,
"responseMimeType": None if response_format is None else ("application/json" if response_format.get("type") == "json_object" else None),
"thinkingConfig": {
"thinkingBudget": thinking_budget,
"includeThoughts": True
} if thinking_budget else None,
},
"tools": tools or [],
"toolConfig": {
"functionCallingConfig": {
"mode": tool_choice.upper(),
"allowedFunctionNames": [tool["function"]["name"] for tool in tools]
}
} if tool_choice else None,
**requestData
},
}
# Remove None values recursively
def clean_none(d):
if isinstance(d, dict):
return {k: clean_none(v) for k, v in d.items() if v is not None}
if isinstance(d, list):
return [clean_none(x) for x in d if x is not None]
return d
req_body = clean_none(req_body)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.auth_manager.get_access_token()}",
}
url = f"{self.api_base}:streamGenerateContent?alt=sse"
# Streaming SSE parsing helper
async def parse_sse_stream(stream: aiohttp.StreamReader) -> AsyncGenerator[Dict[str, Any], None]:
"""Parse SSE stream yielding parsed JSON objects"""
buffer = ""
object_buffer = ""
async for chunk_bytes in stream.iter_any():
chunk = chunk_bytes.decode()
buffer += chunk
lines = buffer.split("\n")
buffer = lines.pop() # Save last incomplete line back
for line in lines:
line = line.strip()
if line == "":
# Empty line indicates end of SSE message -> parse object buffer
if object_buffer:
try:
yield json.loads(object_buffer)
except Exception as e:
debug.error(f"Error parsing SSE JSON: {e}")
object_buffer = ""
elif line.startswith("data: "):
object_buffer += line[6:]
# Final parse when stream ends
if object_buffer:
try:
yield json.loads(object_buffer)
except Exception as e:
debug.error(f"Error parsing final SSE JSON: {e}")
timeout = ClientTimeout(total=None) # No total timeout
connector = get_connector(None, proxy) # Customize connector as needed (supports proxy)
async with ClientSession(headers=headers, timeout=timeout, connector=connector) as session:
async with session.post(url, json=req_body) as resp:
if not resp.ok:
if resp.status == 401:
# Possibly token expired: try login retry logic, omitted here for brevity
                        raise MissingAuthError("Unauthorized (401) from Gemini API")
error_body = await resp.text()
raise RuntimeError(f"Gemini API error {resp.status}: {error_body}")
async for json_data in parse_sse_stream(resp.content):
# Process JSON data according to Gemini API structure
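                    # Each SSE object wraps a GenerateContentResponse, roughly:
                    # {"response": {"candidates": [{"content": {"parts": [...]}}],
                    #               "usageMetadata": {...}}}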
candidates = json_data.get("response", {}).get("candidates", [])
usage_metadata = json_data.get("response", {}).get("usageMetadata", {})
if not candidates:
continue
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
tool_calls = []
for part in parts:
# Real thinking chunks
if part.get("thought") is True and "text" in part:
yield Reasoning(part["text"])
# Function calls from Gemini
elif "functionCall" in part:
tool_calls.append(part["functionCall"])
# Text content
elif "text" in part:
yield part["text"]
# Inline media data
elif "inlineData" in part:
# Media chunk - yield media asynchronously
async for media in save_response_media(part["inlineData"], format_media_prompt(messages)):
yield media
# File data (e.g. external image)
elif "fileData" in part:
# Just yield the file URI for now
file_data = part["fileData"]
yield ImageResponse(file_data.get("fileUri"))
if tool_calls:
yield ToolCalls(tool_calls)
if usage_metadata:
yield Usage(
promptTokens=usage_metadata.get("promptTokenCount", 0),
completionTokens=usage_metadata.get("candidatesTokenCount", 0),
)
class GeminiCLI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini CLI"
login_url = "https://github.com/GewoonJaap/gemini-cli-openai"
default_model = "gemini-2.5-pro"
models = [
"gemini-2.5-pro",
"gemini-2.5-flash",
]
working = True
supports_message_history = True
supports_system_message = True
needs_auth = True
active_by_default = True
auth_manager: AuthManager = None
@classmethod
def get_models(cls, **kwargs):
if cls.live == 0:
if cls.auth_manager is None:
cls.auth_manager = AuthManager(env=os.environ)
if cls.auth_manager.get_access_token() is not None:
cls.live += 1
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
stream: bool = False,
media: MediaListType = None,
tools: Optional[list] = None,
**kwargs
) -> AsyncResult:
if cls.auth_manager is None:
cls.auth_manager = AuthManager(env=os.environ)
# Initialize Gemini CLI provider with auth manager and environment
provider = GeminiCLIProvider(env=os.environ, auth_manager=cls.auth_manager)
async for chunk in provider.stream_content(
model=model,
messages=messages,
stream=stream,
media=media,
tools=tools,
**kwargs
):
            yield chunk
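
# A minimal usage sketch (not part of the provider; assumes valid OAuth credentials
# in ~/.gemini/oauth_creds.json or a GCP_SERVICE_ACCOUNT environment variable):
#
#     import asyncio
#
#     async def main():
#         async for chunk in GeminiCLI.create_async_generator(
#             model="gemini-2.5-flash",
#             messages=[{"role": "user", "content": "Hello!"}],
#         ):
#             print(chunk, end="", flush=True)
#
#     asyncio.run(main())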