File size: 22,407 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
import os
import json
import base64
import time
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import aiohttp
from aiohttp import ClientSession, ClientTimeout

from ...typing import AsyncResult, Messages, MediaListType
from ...errors import MissingAuthError
from ...image.copy_images import save_response_media
from ...image import to_bytes, is_data_an_media
from ...providers.response import Usage, ImageResponse, ToolCalls, Reasoning
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin
from ..helper import get_connector, get_system_prompt, format_media_prompt
from ... import debug

def get_oauth_creds_path():
    """Return the path ``<home>/.gemini/oauth_creds.json`` as a ``Path``.

    Only the path is built here; the file is not opened or checked for
    existence.
    """
    home_dir = Path.home()
    return home_dir.joinpath(".gemini", "oauth_creds.json")

class AuthManager(AuthFileMixin):
    """
    Handles OAuth2 authentication and Google Code Assist API communication.
    Manages token caching, refresh, and API calls.

    Credentials are looked up in this order by ``initialize_auth``:
        1. the in-memory token cache of this instance,
        2. the file returned by ``AuthManager.get_cache_file()``,
        3. the file returned by ``get_oauth_creds_path()``,
        4. the ``GCP_SERVICE_ACCOUNT`` key of ``env``: a JSON string containing
           access_token, expiry_date (ms timestamp), refresh_token.

    Tokens are cached in a plain dict on the instance only; nothing is
    written back to disk by this class.
    """
    parent = "GeminiCLI"

    OAUTH_REFRESH_URL = "https://oauth2.googleapis.com/token"
    OAUTH_CLIENT_ID = "681255809395" + "-oo8ft2oprdrnp9e3aqf6av3hmdib135j" + ".apps.googleusercontent.com"
    OAUTH_CLIENT_SECRET = "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"
    TOKEN_BUFFER_TIME = 5 * 60  # seconds, 5 minutes
    KV_TOKEN_KEY = "oauth_token_cache"

    def __init__(self, env: Dict[str, Any]):
        """Store the environment mapping and start with no token loaded."""
        self.env = env
        self._access_token: Optional[str] = None
        self._expiry: Optional[float] = None  # Unix timestamp in seconds
        # In-memory cache keyed by KV_TOKEN_KEY; values are the dicts built
        # in _cache_token.
        self._token_cache: Dict[str, Dict[str, Any]] = {}

    async def initialize_auth(self) -> None:
        """
        Initialize authentication by using cached token, or refreshing if needed.

        Raises:
            RuntimeError: if the credentials file cannot be read, if no
                credential source is available, or if a refresh is required
                but no refresh token is present / the refresh fails.
        """
        # _get_cached_token only returns entries that are still outside the
        # expiry buffer, so a hit can be used directly.
        cached = await self._get_cached_token()
        if cached:
            self._access_token = cached["access_token"]
            self._expiry = cached["expiry_date"] / 1000  # ms to seconds
            return

        now = time.time()
        path = AuthManager.get_cache_file()
        if not path.exists():
            path = get_oauth_creds_path()
        if path.exists():
            try:
                with path.open("r") as f:
                    creds = json.load(f)
            except Exception as e:
                # Chain the cause so the underlying I/O or JSON error stays
                # visible in the traceback.
                raise RuntimeError(f"Failed to read OAuth credentials from {path}: {e}") from e
        else:
            # Parse credentials from environment
            if "GCP_SERVICE_ACCOUNT" not in self.env:
                raise RuntimeError("GCP_SERVICE_ACCOUNT environment variable not set.")
            creds = json.loads(self.env["GCP_SERVICE_ACCOUNT"])

        refresh_token = creds.get("refresh_token")
        access_token = creds.get("access_token")
        expiry_date = creds.get("expiry_date")  # milliseconds since epoch

        # Use original access token if still valid
        if access_token and expiry_date:
            expires_at = expiry_date / 1000
            if expires_at - now > self.TOKEN_BUFFER_TIME:
                self._access_token = access_token
                self._expiry = expires_at
                await self._cache_token(access_token, expiry_date)
                return

        # Otherwise, refresh token
        if not refresh_token:
            raise RuntimeError("No refresh token found in GCP_SERVICE_ACCOUNT.")

        await self._refresh_and_cache_token(refresh_token)

    async def _refresh_and_cache_token(self, refresh_token: str) -> None:
        """
        Exchange ``refresh_token`` for a new access token and cache it.

        Raises:
            RuntimeError: if the token endpoint does not answer with status
                200 or the answer carries no ``access_token``.
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "client_id": self.OAUTH_CLIENT_ID,
            "client_secret": self.OAUTH_CLIENT_SECRET,
            "refresh_token": refresh_token,
            "grant_type": "refresh_token",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.OAUTH_REFRESH_URL, data=data, headers=headers) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise RuntimeError(f"Token refresh failed: {text}")
                resp_data = await resp.json()

        # The payload is fully read, so the connection is no longer needed.
        access_token = resp_data.get("access_token")
        expires_in = resp_data.get("expires_in", 3600)  # seconds

        if not access_token:
            raise RuntimeError("No access_token in refresh response.")

        self._access_token = access_token
        self._expiry = time.time() + expires_in

        expiry_date_ms = int(self._expiry * 1000)  # milliseconds
        await self._cache_token(access_token, expiry_date_ms)

    async def _cache_token(self, access_token: str, expiry_date: int) -> None:
        """Store the token and its expiry (ms since epoch) in the in-memory cache."""
        token_data = {
            "access_token": access_token,
            "expiry_date": expiry_date,
            "cached_at": int(time.time() * 1000),  # ms
        }
        self._token_cache[self.KV_TOKEN_KEY] = token_data

    async def _get_cached_token(self) -> Optional[Dict[str, Any]]:
        """Return the cached token dict if it expires later than the buffer time, else None."""
        cached = self._token_cache.get(self.KV_TOKEN_KEY)
        if cached:
            expires_at = cached["expiry_date"] / 1000
            if expires_at - time.time() > self.TOKEN_BUFFER_TIME:
                return cached
        return None

    async def clear_token_cache(self) -> None:
        """Forget the current token, including the in-memory cache entry.

        The cache entry must be dropped as well: otherwise the next
        ``initialize_auth`` call would hand back the very token that was
        just discarded.
        """
        self._access_token = None
        self._expiry = None
        self._token_cache.pop(self.KV_TOKEN_KEY, None)

    def get_access_token(self) -> Optional[str]:
        """Return the current access token, or None if unset or within the expiry buffer."""
        if (
            self._access_token is not None
            and self._expiry is not None
            and self._expiry - time.time() > self.TOKEN_BUFFER_TIME
        ):
            return self._access_token
        return None

    async def call_endpoint(self, method: str, body: Dict[str, Any], is_retry: bool = False) -> Any:
        """
        Call Google Code Assist API endpoint with JSON body.

        Automatically retries once on 401 Unauthorized by refreshing auth.

        Args:
            method: name appended after ``v1internal:`` in the request URL.
            body: JSON-serialisable request payload.
            is_retry: True on the second attempt; a 401 is then reported as
                an error instead of triggering another retry.

        Returns:
            The decoded JSON response.

        Raises:
            RuntimeError: on any non-OK response that is not retried, or
                when authentication cannot be initialised.
        """
        if not self.get_access_token():
            await self.initialize_auth()

        url = f"https://cloudcode-pa.googleapis.com/v1internal:{method}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.get_access_token()}",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=body) as resp:
                if resp.status != 401 or is_retry:
                    if not resp.ok:
                        text = await resp.text()
                        raise RuntimeError(f"API call failed with status {resp.status}: {text}")
                    return await resp.json()

        # First attempt got a 401. The session above is already closed, so
        # the retry does not hold a second connection open while it runs.
        await self.clear_token_cache()
        await self.initialize_auth()
        return await self.call_endpoint(method, body, is_retry=True)

class GeminiCLIProvider():
    """Streams chat completions from the Code Assist ``v1internal`` endpoint.

    Converts OpenAI-style chat messages into the Gemini ``contents`` layout,
    posts them to ``streamGenerateContent`` and yields the decoded response
    pieces (text, reasoning, tool calls, media, usage).
    """
    url = "https://cloud.google.com/code-assist"
    api_base = "https://cloudcode-pa.googleapis.com/v1internal"

    # Required for authentication and token management; Expects a compatible AuthManager instance
    auth_manager: AuthManager
    env: dict

    def __init__(self, env: dict, auth_manager: AuthManager):
        """Keep the environment mapping and the auth manager used for all calls."""
        self.env = env
        self.auth_manager = auth_manager

        # Cache for discovered project ID
        self._project_id: Optional[str] = None

    async def discover_project_id(self) -> str:
        """Return the project id to send with requests.

        ``GEMINI_PROJECT_ID`` from ``env`` wins; otherwise a previously
        discovered id is reused; otherwise ``loadCodeAssist`` is called and
        its ``cloudaicompanionProject`` value is cached and returned.

        Raises:
            RuntimeError: if the lookup fails or returns no project.
        """
        if self.env.get("GEMINI_PROJECT_ID"):
            return self.env["GEMINI_PROJECT_ID"]
        if self._project_id:
            return self._project_id

        try:
            load_response = await self.auth_manager.call_endpoint(
                "loadCodeAssist",
                {
                    "cloudaicompanionProject": "default-project",
                    "metadata": {"duetProject": "default-project"},
                },
            )
            project = load_response.get("cloudaicompanionProject")
            if project:
                self._project_id = project
                return project
            raise RuntimeError(
                "Project ID discovery failed - set GEMINI_PROJECT_ID in environment."
            )
        except Exception as e:
            debug.error(f"Failed to discover project ID: {e}")
            # Chain the original failure so it is not lost behind the
            # generic message.
            raise RuntimeError(
                "Could not discover project ID. Ensure authentication or set GEMINI_PROJECT_ID."
            ) from e

    @staticmethod
    def _messages_to_gemini_format(messages: list, media: MediaListType) -> List[Dict[str, Any]]:
        """Convert OpenAI-style chat messages into a list of Gemini messages.

        Args:
            messages: message dicts with at least a ``role`` key.
            media: optional iterable of ``(data, filename)`` pairs that are
                appended to the parts of the last message. String ``data``
                is sent as a ``fileData`` URI, anything else is converted
                with ``to_bytes`` and sent inline as base64.

        Returns:
            A list of ``{"role": ..., "parts": [...]}`` dicts, where role is
            ``"model"`` for assistant messages and ``"user"`` otherwise.
        """
        format_messages = []
        for msg in messages:
            role = "model" if msg["role"] == "assistant" else "user"
            content = msg.get("content")

            # Handle tool role (OpenAI style)
            if msg["role"] == "tool":
                parts = [
                    {
                        "functionResponse": {
                            "name": msg.get("tool_call_id", "unknown_function"),
                            "response": {
                                "result": (
                                    content
                                    if isinstance(content, str)
                                    else json.dumps(content)
                                )
                            },
                        }
                    }
                ]

            # Handle assistant messages with tool calls
            elif msg["role"] == "assistant" and msg.get("tool_calls"):
                parts = []
                if isinstance(content, str) and content.strip():
                    parts.append({"text": content})
                for tool_call in msg["tool_calls"]:
                    if tool_call.get("type") != "function":
                        continue
                    arguments = tool_call["function"]["arguments"]
                    # Arguments normally arrive as a JSON string; accept an
                    # already-decoded value and treat a blank string as "no
                    # arguments".
                    if isinstance(arguments, str):
                        arguments = json.loads(arguments) if arguments.strip() else {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": tool_call["function"]["name"],
                                "args": arguments,
                            }
                        }
                    )

            # Handle string content
            elif isinstance(content, str):
                parts = [{"text": content}]

            # Handle array content (possibly multimodal)
            elif isinstance(content, list):
                # Start from a fresh list for every message so parts never
                # leak over from a previous iteration.
                parts = []
                for item in content:
                    ctype = item.get("type")
                    if ctype == "text":
                        parts.append({"text": item["text"]})
                    elif ctype == "image_url":
                        image_url = item.get("image_url", {}).get("url")
                        if not image_url:
                            continue
                        if image_url.startswith("data:"):
                            # Inline base64 data image
                            prefix, b64data = image_url.split(",", 1)
                            mime_type = prefix.split(":")[1].split(";")[0]
                            parts.append({"inlineData": {"mimeType": mime_type, "data": b64data}})
                        else:
                            parts.append(
                                {
                                    "fileData": {
                                        "mimeType": "image/jpeg",  # Could improve by validation
                                        "fileUri": image_url,
                                    }
                                }
                            )
            else:
                parts = [{"text": str(content)}]
            format_messages.append({"role": role, "parts": parts})

        if media:
            if not format_messages:
                format_messages.append({"role": "user", "parts": []})
            for media_data, filename in media:
                if isinstance(media_data, str):
                    if not filename:
                        filename = media_data
                    extension = filename.split(".")[-1].replace("jpg", "jpeg")
                    format_messages[-1]["parts"].append(
                        {
                            "fileData": {
                                "mimeType": f"image/{extension}",
                                # The string media entry itself is the URI.
                                "fileUri": media_data,
                            }
                        }
                    )
                else:
                    media_data = to_bytes(media_data)
                    format_messages[-1]["parts"].append({
                        "inlineData": {
                            "mimeType": is_data_an_media(media_data, filename),
                            "data": base64.b64encode(media_data).decode()
                        }
                    })
        return format_messages

    async def stream_content(
        self,
        model: str,
        messages: Messages,
        *,
        proxy: Optional[str] = None,
        thinking_budget: Optional[int] = None,
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> AsyncGenerator:
        """Send a streaming generation request and yield the response pieces.

        Yields, per received event: ``Reasoning`` for thought parts, plain
        strings for text parts, whatever ``save_response_media`` yields for
        inline data, ``ImageResponse`` for file data, then ``ToolCalls`` and
        ``Usage`` when present. ``media`` is read from ``kwargs``.

        Raises:
            MissingAuthError: when the API answers 401.
            RuntimeError: for any other non-OK response, or when
                authentication / project discovery fails.
        """
        # Local import keeps this block self-contained; codecs is stdlib.
        import codecs

        await self.auth_manager.initialize_auth()

        project_id = await self.discover_project_id()

        # Convert messages to Gemini format; system/developer messages are
        # sent separately as the system instruction.
        contents = self._messages_to_gemini_format(
            [m for m in messages if m["role"] not in ["developer", "system"]],
            media=kwargs.get("media", None),
        )
        system_prompt = get_system_prompt(messages)
        request_extras = {}
        if system_prompt:
            request_extras["system_instruction"] = {"parts": {"text": system_prompt}}

        response_mime_type = None
        if response_format is not None and response_format.get("type") == "json_object":
            response_mime_type = "application/json"

        # The public Gemini generation config names this field
        # "stopSequences" and takes a list of strings.
        stop_sequences = [stop] if isinstance(stop, str) else stop

        # NOTE(review): tools are forwarded unchanged; the allowedFunctionNames
        # lookup below assumes OpenAI-style {"function": {"name": ...}} items —
        # confirm the endpoint accepts that layout.
        req_body = {
            "model": model,
            "project": project_id,
            "request": {
                "contents": contents,
                "generationConfig": {
                    "maxOutputTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p,
                    "stopSequences": stop_sequences,
                    "presencePenalty": presence_penalty,
                    "frequencyPenalty": frequency_penalty,
                    "seed": seed,
                    "responseMimeType": response_mime_type,
                    "thinkingConfig": {
                        "thinkingBudget": thinking_budget,
                        "includeThoughts": True
                    } if thinking_budget else None,
                },
                "tools": tools or [],
                "toolConfig": {
                    "functionCallingConfig": {
                        "mode": tool_choice.upper(),
                        # tools may be None even when tool_choice is given.
                        "allowedFunctionNames": [tool["function"]["name"] for tool in (tools or [])]
                    }
                } if tool_choice else None,
                **request_extras
            },
        }

        # Remove None values recursively
        def clean_none(d):
            if isinstance(d, dict):
                return {k: clean_none(v) for k, v in d.items() if v is not None}
            if isinstance(d, list):
                return [clean_none(x) for x in d if x is not None]
            return d

        req_body = clean_none(req_body)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_manager.get_access_token()}",
        }

        url = f"{self.api_base}:streamGenerateContent?alt=sse"

        def parse_event(raw: str, error_label: str) -> Any:
            """Decode one accumulated SSE payload; log and return None on bad JSON."""
            try:
                return json.loads(raw)
            except json.JSONDecodeError as e:
                debug.error(f"{error_label}: {e}")
                return None

        # Streaming SSE parsing helper
        async def parse_sse_stream(stream: aiohttp.StreamReader) -> AsyncGenerator[Dict[str, Any], None]:
            """Parse SSE stream yielding parsed JSON objects"""
            # An incremental decoder keeps the bytes of a multi-byte character
            # that is split across two network chunks until it is complete.
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            buffer = ""
            object_buffer = ""

            async for chunk_bytes in stream.iter_any():
                buffer += decoder.decode(chunk_bytes)
                lines = buffer.split("\n")
                buffer = lines.pop()  # Save last incomplete line back

                for line in lines:
                    line = line.strip()
                    if line == "":
                        # Empty line indicates end of SSE message -> parse object buffer
                        if object_buffer:
                            event = parse_event(object_buffer, "Error parsing SSE JSON")
                            object_buffer = ""
                            if event is not None:
                                yield event
                    elif line.startswith("data: "):
                        object_buffer += line[6:]

            # The stream may end without a trailing newline; pick up a last
            # unterminated data line instead of dropping it.
            buffer += decoder.decode(b"", final=True)
            tail = buffer.strip()
            if tail.startswith("data: "):
                object_buffer += tail[6:]

            # Final parse when stream ends
            if object_buffer:
                event = parse_event(object_buffer, "Error parsing final SSE JSON")
                if event is not None:
                    yield event

        timeout = ClientTimeout(total=None)  # No total timeout
        connector = get_connector(None, proxy)  # Customize connector as needed (supports proxy)

        async with ClientSession(headers=headers, timeout=timeout, connector=connector) as session:
            async with session.post(url, json=req_body) as resp:
                if not resp.ok:
                    if resp.status == 401:
                        # Possibly token expired: no automatic retry here
                        raise MissingAuthError("Unauthorized (401) from Gemini API")
                    error_body = await resp.text()
                    raise RuntimeError(f"Gemini API error {resp.status}: {error_body}")

                async for json_data in parse_sse_stream(resp.content):
                    # Process JSON data according to Gemini API structure
                    candidates = json_data.get("response", {}).get("candidates", [])
                    usage_metadata = json_data.get("response", {}).get("usageMetadata", {})

                    if not candidates:
                        continue

                    # Only the first candidate is consumed.
                    candidate = candidates[0]
                    content = candidate.get("content", {})
                    parts = content.get("parts", [])

                    tool_calls = []

                    for part in parts:
                        # Real thinking chunks
                        if part.get("thought") is True and "text" in part:
                            yield Reasoning(part["text"])

                        # Function calls from Gemini
                        elif "functionCall" in part:
                            tool_calls.append(part["functionCall"])

                        # Text content
                        elif "text" in part:
                            yield part["text"]

                        # Inline media data
                        elif "inlineData" in part:
                            # Media chunk - yield media asynchronously
                            async for media in save_response_media(part["inlineData"], format_media_prompt(messages)):
                                yield media

                        # File data (e.g. external image)
                        elif "fileData" in part:
                            # Just yield the file URI for now
                            file_data = part["fileData"]
                            yield ImageResponse(file_data.get("fileUri"))

                    if tool_calls:
                        yield ToolCalls(tool_calls)
                    if usage_metadata:
                        yield Usage(
                            promptTokens=usage_metadata.get("promptTokenCount", 0),
                            completionTokens=usage_metadata.get("candidatesTokenCount", 0),
                        )

class GeminiCLI(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider entry point that streams responses through ``GeminiCLIProvider``.

    A single ``AuthManager`` built from ``os.environ`` is created lazily and
    shared on the class.
    """
    label = "Google Gemini CLI"
    login_url = "https://github.com/GewoonJaap/gemini-cli-openai"

    default_model = "gemini-2.5-pro"
    models = [
        "gemini-2.5-pro",
        "gemini-2.5-flash",
    ]

    working = True
    supports_message_history = True
    supports_system_message = True
    needs_auth = True
    active_by_default = True

    # Shared across all calls; stays None until first needed.
    auth_manager: AuthManager = None

    @classmethod
    def get_models(cls, **kwargs):
        """Return the static model list.

        While ``cls.live`` is still 0, the shared auth manager is created if
        needed and ``cls.live`` is incremented when it already holds a
        usable access token.
        """
        # NOTE(review): cls.live is not defined in this class; presumably it
        # comes from a base class — confirm.
        if cls.live != 0:
            return cls.models

        if cls.auth_manager is None:
            cls.auth_manager = AuthManager(env=os.environ)
        has_token = cls.auth_manager.get_access_token() is not None
        if has_token:
            cls.live += 1
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        media: MediaListType = None,
        tools: Optional[list] = None,
        **kwargs
    ) -> AsyncResult:
        """Yield every chunk produced by ``GeminiCLIProvider.stream_content``.

        ``stream``, ``media``, ``tools`` and any extra keyword arguments are
        forwarded to the provider unchanged.
        """
        if cls.auth_manager is None:
            cls.auth_manager = AuthManager(env=os.environ)

        # Fresh provider per call, backed by the shared auth manager.
        backend = GeminiCLIProvider(env=os.environ, auth_manager=cls.auth_manager)

        forwarded = {"stream": stream, "media": media, "tools": tools, **kwargs}
        chunk_source = backend.stream_content(model=model, messages=messages, **forwarded)
        async for piece in chunk_source:
            yield piece