|
|
import os |
|
|
import json |
|
|
import base64 |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union |
|
|
|
|
|
import aiohttp |
|
|
from aiohttp import ClientSession, ClientTimeout |
|
|
|
|
|
from ...typing import AsyncResult, Messages, MediaListType |
|
|
from ...errors import MissingAuthError |
|
|
from ...image.copy_images import save_response_media |
|
|
from ...image import to_bytes, is_data_an_media |
|
|
from ...providers.response import Usage, ImageResponse, ToolCalls, Reasoning |
|
|
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin |
|
|
from ..helper import get_connector, get_system_prompt, format_media_prompt |
|
|
from ... import debug |
|
|
|
|
|
def get_oauth_creds_path():
    """Return the default location of the Gemini CLI OAuth credentials file
    (``~/.gemini/oauth_creds.json``)."""
    return Path.home().joinpath(".gemini", "oauth_creds.json")
|
|
|
|
|
class AuthManager(AuthFileMixin):
    """
    Handles OAuth2 authentication and Google Code Assist API communication.
    Manages token caching, refresh, and API calls.

    Requires environment dict-like object with keys:
    - GCP_SERVICE_ACCOUNT: JSON string with OAuth2 credentials, containing:
        access_token, expiry_date (ms timestamp), refresh_token
    - Optionally supports cache storage via a KV storage interface implementing:
        get(key) -> value or None,
        put(key, value, expiration_seconds),
        delete(key)
    """
    parent = "GeminiCLI"

    OAUTH_REFRESH_URL = "https://oauth2.googleapis.com/token"
    OAUTH_CLIENT_ID = "681255809395" + "-oo8ft2oprdrnp9e3aqf6av3hmdib135j" + ".apps.googleusercontent.com"
    OAUTH_CLIENT_SECRET = "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"
    # Treat a token as expired this many seconds before its real expiry.
    TOKEN_BUFFER_TIME = 5 * 60
    KV_TOKEN_KEY = "oauth_token_cache"

    def __init__(self, env: Dict[str, Any]):
        self.env = env
        self._access_token: Optional[str] = None
        self._expiry: Optional[float] = None  # Unix timestamp in seconds
        self._token_cache: Dict[str, Any] = {}  # in-memory KV-style token cache

    async def initialize_auth(self) -> None:
        """
        Initialize authentication by using cached token, or refreshing if needed.
        Raises RuntimeError if no valid token can be obtained.
        """
        cached = await self._get_cached_token()
        now = time.time()
        if cached:
            expires_at = cached["expiry_date"] / 1000  # ms -> seconds
            if expires_at - now > self.TOKEN_BUFFER_TIME:
                self._access_token = cached["access_token"]
                self._expiry = expires_at
                return

        # Fall back to credentials on disk: first the provider cache file,
        # then the Gemini CLI's own ~/.gemini/oauth_creds.json.
        path = AuthManager.get_cache_file()
        if not path.exists():
            path = get_oauth_creds_path()
        if path.exists():
            try:
                with path.open("r") as f:
                    creds = json.load(f)
            except Exception as e:
                raise RuntimeError(f"Failed to read OAuth credentials from {path}: {e}")
        else:
            # Last resort: credentials passed in via the environment.
            if "GCP_SERVICE_ACCOUNT" not in self.env:
                raise RuntimeError("GCP_SERVICE_ACCOUNT environment variable not set.")
            creds = json.loads(self.env["GCP_SERVICE_ACCOUNT"])

        refresh_token = creds.get("refresh_token")
        access_token = creds.get("access_token")
        expiry_date = creds.get("expiry_date")

        # Use the stored access token directly if it is still comfortably valid.
        if access_token and expiry_date:
            expires_at = expiry_date / 1000
            if expires_at - now > self.TOKEN_BUFFER_TIME:
                self._access_token = access_token
                self._expiry = expires_at
                await self._cache_token(access_token, expiry_date)
                return

        if not refresh_token:
            # Fix: credentials may come from a file as well as from
            # GCP_SERVICE_ACCOUNT, so the message must not claim a single source.
            raise RuntimeError("No refresh token found in OAuth credentials.")

        await self._refresh_and_cache_token(refresh_token)

    async def _refresh_and_cache_token(self, refresh_token: str) -> None:
        """Exchange the refresh token for a new access token and cache it.

        Raises RuntimeError if the refresh request fails or returns no token.
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "client_id": self.OAUTH_CLIENT_ID,
            "client_secret": self.OAUTH_CLIENT_SECRET,
            "refresh_token": refresh_token,
            "grant_type": "refresh_token",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.OAUTH_REFRESH_URL, data=data, headers=headers) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise RuntimeError(f"Token refresh failed: {text}")
                resp_data = await resp.json()
                access_token = resp_data.get("access_token")
                expires_in = resp_data.get("expires_in", 3600)  # seconds; common OAuth default

        if not access_token:
            raise RuntimeError("No access_token in refresh response.")

        self._access_token = access_token
        self._expiry = time.time() + expires_in

        # The cache stores expiry as a millisecond timestamp.
        expiry_date_ms = int(self._expiry * 1000)
        await self._cache_token(access_token, expiry_date_ms)

    async def _cache_token(self, access_token: str, expiry_date: int) -> None:
        """Store the token in the in-memory cache (expiry_date in milliseconds)."""
        token_data = {
            "access_token": access_token,
            "expiry_date": expiry_date,
            "cached_at": int(time.time() * 1000),
        }
        self._token_cache[self.KV_TOKEN_KEY] = token_data

    async def _get_cached_token(self) -> Optional[Dict[str, Any]]:
        """Return the cached token dict if it is still valid, else None."""
        cached = self._token_cache.get(self.KV_TOKEN_KEY)
        if cached:
            expires_at = cached["expiry_date"] / 1000
            if expires_at - time.time() > self.TOKEN_BUFFER_TIME:
                return cached
        return None

    async def clear_token_cache(self) -> None:
        """Forget the current token everywhere so the next call re-authenticates."""
        self._access_token = None
        self._expiry = None
        # Fix: also drop the KV cache entry. Previously only the in-memory
        # fields were reset, so the 401-retry path (call_endpoint ->
        # clear_token_cache -> initialize_auth) re-installed the same
        # rejected token from _get_cached_token().
        self._token_cache.pop(self.KV_TOKEN_KEY, None)

    def get_access_token(self) -> Optional[str]:
        """Return the in-memory access token if it is still valid, else None."""
        if (
            self._access_token is not None
            and self._expiry is not None
            and self._expiry - time.time() > self.TOKEN_BUFFER_TIME
        ):
            return self._access_token
        return None

    async def call_endpoint(self, method: str, body: Dict[str, Any], is_retry: bool = False) -> Any:
        """
        Call Google Code Assist API endpoint with JSON body.

        Automatically retries once on 401 Unauthorized by refreshing auth.
        Raises RuntimeError on any other non-OK response.
        """
        if not self.get_access_token():
            await self.initialize_auth()

        url = f"https://cloudcode-pa.googleapis.com/v1internal:{method}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.get_access_token()}",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=body) as resp:
                if resp.status == 401 and not is_retry:
                    # Token rejected: wipe it and retry exactly once.
                    await self.clear_token_cache()
                    await self.initialize_auth()
                    return await self.call_endpoint(method, body, is_retry=True)
                elif not resp.ok:
                    text = await resp.text()
                    raise RuntimeError(f"API call failed with status {resp.status}: {text}")

                return await resp.json()
|
|
|
|
|
class GeminiCLIProvider():
    """Client for the Google Code Assist (Gemini CLI) internal streaming API."""
    url = "https://cloud.google.com/code-assist"
    api_base = "https://cloudcode-pa.googleapis.com/v1internal"

    auth_manager: AuthManager
    env: dict

    def __init__(self, env: dict, auth_manager: AuthManager):
        self.env = env
        self.auth_manager = auth_manager
        self._project_id: Optional[str] = None  # cached result of discover_project_id()

    async def discover_project_id(self) -> str:
        """Return the GCP project id: from GEMINI_PROJECT_ID, the cached value,
        or by asking the loadCodeAssist endpoint.

        Raises RuntimeError when no project id can be determined.
        """
        if self.env.get("GEMINI_PROJECT_ID"):
            return self.env["GEMINI_PROJECT_ID"]
        if self._project_id:
            return self._project_id

        try:
            load_response = await self.auth_manager.call_endpoint(
                "loadCodeAssist",
                {
                    "cloudaicompanionProject": "default-project",
                    "metadata": {"duetProject": "default-project"},
                },
            )
            project = load_response.get("cloudaicompanionProject")
            if project:
                self._project_id = project
                return project
            raise RuntimeError(
                "Project ID discovery failed - set GEMINI_PROJECT_ID in environment."
            )
        except Exception as e:
            debug.error(f"Failed to discover project ID: {e}")
            raise RuntimeError(
                "Could not discover project ID. Ensure authentication or set GEMINI_PROJECT_ID."
            )

    @staticmethod
    def _messages_to_gemini_format(messages: list, media: MediaListType) -> List[Dict[str, Any]]:
        """
        Convert OpenAI-style chat messages (plus optional media attachments)
        into a list of Gemini "contents" entries of the form {"role", "parts"}.
        """
        format_messages = []
        for msg in messages:
            # Gemini only knows the "user" and "model" roles.
            role = "model" if msg["role"] == "assistant" else "user"

            if msg["role"] == "tool":
                # Tool results are sent back as functionResponse parts.
                parts = [
                    {
                        "functionResponse": {
                            "name": msg.get("tool_call_id", "unknown_function"),
                            "response": {
                                "result": (
                                    msg["content"]
                                    if isinstance(msg["content"], str)
                                    else json.dumps(msg["content"])
                                )
                            },
                        }
                    }
                ]  # Fix: a stray trailing comma here used to make `parts` a one-element tuple.
            elif msg["role"] == "assistant" and msg.get("tool_calls"):
                parts = []
                if isinstance(msg["content"], str) and msg["content"].strip():
                    parts.append({"text": msg["content"]})
                for tool_call in msg["tool_calls"]:
                    if tool_call.get("type") == "function":
                        parts.append(
                            {
                                "functionCall": {
                                    "name": tool_call["function"]["name"],
                                    "args": json.loads(tool_call["function"]["arguments"]),
                                }
                            }
                        )
            elif isinstance(msg["content"], str):
                parts = [{"text": msg["content"]}]
            elif isinstance(msg["content"], list):
                # Fix: `parts` was never initialized in this branch, so content
                # was appended to the previous message's parts (or raised NameError).
                parts = []
                for content in msg["content"]:
                    ctype = content.get("type")
                    if ctype == "text":
                        parts.append({"text": content["text"]})
                    elif ctype == "image_url":
                        image_url = content.get("image_url", {}).get("url")
                        if not image_url:
                            continue
                        if image_url.startswith("data:"):
                            # Inline data URI: split the mime type from the base64 payload.
                            prefix, b64data = image_url.split(",", 1)
                            mime_type = prefix.split(":")[1].split(";")[0]
                            parts.append({"inlineData": {"mimeType": mime_type, "data": b64data}})
                        else:
                            parts.append(
                                {
                                    "fileData": {
                                        "mimeType": "image/jpeg",
                                        "fileUri": image_url,
                                    }
                                }
                            )
            else:
                parts = [{"text": str(msg["content"])}]
            format_messages.append({"role": role, "parts": parts})
        if media:
            if not format_messages:
                format_messages.append({"role": "user", "parts": []})
            for media_data, filename in media:
                if isinstance(media_data, str):
                    # A string entry is a URL/path reference, not raw bytes.
                    if not filename:
                        filename = media_data
                    extension = filename.split(".")[-1].replace("jpg", "jpeg")
                    format_messages[-1]["parts"].append(
                        {
                            "fileData": {
                                "mimeType": f"image/{extension}",
                                # Fix: referenced the undefined name `image_url`
                                # here, which always raised NameError.
                                "fileUri": media_data,
                            }
                        }
                    )
                else:
                    media_data = to_bytes(media_data)
                    format_messages[-1]["parts"].append({
                        "inlineData": {
                            "mimeType": is_data_an_media(media_data, filename),
                            "data": base64.b64encode(media_data).decode()
                        }
                    })
        return format_messages

    async def stream_content(
        self,
        model: str,
        messages: Messages,
        *,
        proxy: Optional[str] = None,
        thinking_budget: Optional[int] = None,
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> AsyncGenerator:
        """
        Stream a generateContent response from the Code Assist API over SSE.

        Yields text chunks, Reasoning, ToolCalls, ImageResponse, saved media
        and Usage objects as they arrive.
        """
        await self.auth_manager.initialize_auth()

        project_id = await self.discover_project_id()

        # System/developer messages are stripped here and sent separately
        # via system_instruction below.
        contents = self._messages_to_gemini_format([m for m in messages if m["role"] not in ["developer", "system"]], media=kwargs.get("media", None))
        system_prompt = get_system_prompt(messages)
        request_extras = {}
        if system_prompt:
            # Fix: "parts" is a repeated proto field and must be a JSON array,
            # not a single object.
            request_extras["system_instruction"] = {"parts": [{"text": system_prompt}]}

        req_body = {
            "model": model,
            "project": project_id,
            "request": {
                "contents": contents,
                "generationConfig": {
                    "maxOutputTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p,
                    # Fix: the Gemini GenerationConfig field is "stopSequences"
                    # (a list of strings); "stop" is not a valid field name.
                    "stopSequences": [stop] if isinstance(stop, str) else stop,
                    "presencePenalty": presence_penalty,
                    "frequencyPenalty": frequency_penalty,
                    "seed": seed,
                    "responseMimeType": None if response_format is None else ("application/json" if response_format.get("type") == "json_object" else None),
                    "thinkingConfig": {
                        "thinkingBudget": thinking_budget,
                        "includeThoughts": True
                    } if thinking_budget else None,
                },
                "tools": tools or [],
                "toolConfig": {
                    "functionCallingConfig": {
                        "mode": tool_choice.upper(),
                        "allowedFunctionNames": [tool["function"]["name"] for tool in tools]
                    }
                } if tool_choice else None,
                **request_extras
            },
        }

        def clean_none(d):
            # Recursively drop None values so the API never sees explicit nulls.
            if isinstance(d, dict):
                return {k: clean_none(v) for k, v in d.items() if v is not None}
            if isinstance(d, list):
                return [clean_none(x) for x in d if x is not None]
            return d

        req_body = clean_none(req_body)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_manager.get_access_token()}",
        }

        url = f"{self.api_base}:streamGenerateContent?alt=sse"

        async def parse_sse_stream(stream: aiohttp.StreamReader) -> AsyncGenerator[Dict[str, Any], None]:
            """Parse SSE stream yielding parsed JSON objects"""
            buffer = ""
            object_buffer = ""

            async for chunk_bytes in stream.iter_any():
                chunk = chunk_bytes.decode()
                buffer += chunk
                lines = buffer.split("\n")
                # Keep any partial trailing line for the next network chunk.
                buffer = lines.pop()

                for line in lines:
                    line = line.strip()
                    if line == "":
                        # A blank line terminates one SSE event.
                        if object_buffer:
                            try:
                                yield json.loads(object_buffer)
                            except Exception as e:
                                debug.error(f"Error parsing SSE JSON: {e}")
                            object_buffer = ""
                    elif line.startswith("data: "):
                        object_buffer += line[6:]

            # Flush a final event that was not followed by a blank line.
            if object_buffer:
                try:
                    yield json.loads(object_buffer)
                except Exception as e:
                    debug.error(f"Error parsing final SSE JSON: {e}")

        timeout = ClientTimeout(total=None)  # streaming responses may run long
        connector = get_connector(None, proxy)

        async with ClientSession(headers=headers, timeout=timeout, connector=connector) as session:
            async with session.post(url, json=req_body) as resp:
                if not resp.ok:
                    if resp.status == 401:
                        raise MissingAuthError("Unauthorized (401) from Gemini API")
                    error_body = await resp.text()
                    raise RuntimeError(f"Gemini API error {resp.status}: {error_body}")

                async for json_data in parse_sse_stream(resp.content):
                    candidates = json_data.get("response", {}).get("candidates", [])
                    usage_metadata = json_data.get("response", {}).get("usageMetadata", {})

                    if not candidates:
                        continue

                    candidate = candidates[0]
                    content = candidate.get("content", {})
                    parts = content.get("parts", [])

                    tool_calls = []

                    for part in parts:
                        if part.get("thought") is True and "text" in part:
                            # Model "thought" parts are surfaced as Reasoning chunks.
                            yield Reasoning(part["text"])
                        elif "functionCall" in part:
                            tool_calls.append(part["functionCall"])
                        elif "text" in part:
                            yield part["text"]
                        elif "inlineData" in part:
                            # Inline binary payloads are persisted and yielded as media.
                            async for media in save_response_media(part["inlineData"], format_media_prompt(messages)):
                                yield media
                        elif "fileData" in part:
                            file_data = part["fileData"]
                            yield ImageResponse(file_data.get("fileUri"))

                    if tool_calls:
                        yield ToolCalls(tool_calls)
                    if usage_metadata:
                        yield Usage(
                            promptTokens=usage_metadata.get("promptTokenCount", 0),
                            completionTokens=usage_metadata.get("candidatesTokenCount", 0),
                        )
|
|
|
|
|
class GeminiCLI(AsyncGeneratorProvider, ProviderModelMixin):
    """g4f provider entry point for Google's Gemini CLI (Code Assist) backend."""

    label = "Google Gemini CLI"
    login_url = "https://github.com/GewoonJaap/gemini-cli-openai"

    default_model = "gemini-2.5-pro"
    models = ["gemini-2.5-pro", "gemini-2.5-flash"]

    working = True
    supports_message_history = True
    supports_system_message = True
    needs_auth = True
    active_by_default = True

    # Shared, lazily created authentication manager.
    auth_manager: AuthManager = None

    @classmethod
    def get_models(cls, **kwargs):
        """Return the static model list, bumping the live counter once a valid token exists."""
        if cls.live == 0:
            if cls.auth_manager is None:
                cls.auth_manager = AuthManager(env=os.environ)
            if cls.auth_manager.get_access_token() is not None:
                cls.live += 1
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        media: MediaListType = None,
        tools: Optional[list] = None,
        **kwargs
    ) -> AsyncResult:
        """Delegate generation to a GeminiCLIProvider bound to the shared auth manager."""
        if cls.auth_manager is None:
            cls.auth_manager = AuthManager(env=os.environ)

        client = GeminiCLIProvider(env=os.environ, auth_manager=cls.auth_manager)

        async for piece in client.stream_content(
            model=model,
            messages=messages,
            stream=stream,
            media=media,
            tools=tools,
            **kwargs
        ):
            yield piece