|
|
import os |
|
|
import json |
|
|
import time |
|
|
import asyncio |
|
|
import uuid |
|
|
from typing import Optional, Dict, Union |
|
|
from .stubs import IQwenOAuth2Client, ErrorDataDict |
|
|
from pathlib import Path |
|
|
import threading |
|
|
|
|
|
from ..base_provider import AuthFileMixin |
|
|
from ... import debug |
|
|
|
|
|
# Location of the credential and lock files, relative to the user's home directory.
QWEN_DIR = ".qwen"
QWEN_CREDENTIAL_FILENAME = "oauth_creds.json"
QWEN_LOCK_FILENAME = "oauth_creds.lock"

# A token is treated as expired this many milliseconds before its expiry_date.
TOKEN_REFRESH_BUFFER_MS = 30_000

# A lock file older than this many milliseconds is considered stale.
LOCK_TIMEOUT_MS = 10_000

# Minimum number of milliseconds between two checks of the credential file's mtime.
CACHE_CHECK_INTERVAL_MS = 1_000
|
|
|
|
|
|
|
|
def isErrorResponse(
    response: Union[Dict, ErrorDataDict]
) -> bool:
    """Tell whether *response* carries an ``"error"`` entry.

    Args:
        response: Mapping returned by the OAuth client.

    Returns:
        True when the key ``"error"`` is present, False otherwise.
    """
    has_error_key = "error" in response
    return has_error_key
|
|
|
|
|
|
|
|
class TokenError:
    """String identifiers for the failure categories stored in ``TokenManagerError.type``."""

    # The token refresh request failed or returned an unusable response.
    REFRESH_FAILED = "REFRESH_FAILED"
    # The current credentials contain no refresh token.
    NO_REFRESH_TOKEN = "NO_REFRESH_TOKEN"
    # The lock file could not be acquired within the allowed attempts.
    LOCK_TIMEOUT = "LOCK_TIMEOUT"
    # The credential file could not be read, parsed or validated.
    FILE_ACCESS_ERROR = "FILE_ACCESS_ERROR"
    # Network-level failure category.
    NETWORK_ERROR = "NETWORK_ERROR"
|
|
|
|
|
|
|
|
class TokenManagerError(Exception):
    """Exception raised by the token manager, tagged with a failure category.

    Attributes:
        type: One of the ``TokenError`` string constants.
        original_error: The underlying exception, or None when there is none.
    """

    def __init__(self, type_: str, message: str, original_error: Optional[Exception] = None):
        """Store the category and optional cause; *message* becomes ``str(self)``."""
        super().__init__(message)
        self.original_error = original_error
        self.type = type_
|
|
|
|
|
|
|
|
class SharedTokenManager(AuthFileMixin):
    """Keep Qwen OAuth credentials cached in memory and in sync with disk.

    Credentials are read from a JSON file, reloaded when the file's
    modification time advances, and refreshed through an OAuth client when
    they are missing, expired, or a refresh is forced.  A lock file next to
    the home-directory credential file serializes refreshes.

    Use ``getInstance()`` to obtain the shared instance.
    """

    parent = "QwenCode"
    _instance: Optional["SharedTokenManager"] = None
    # Guards the lazy creation of the shared instance in getInstance().
    _lock = threading.Lock()

    def __init__(self):
        # Both timestamps are integer milliseconds since the epoch.
        self.memory_cache = {
            "credentials": None,
            "file_mod_time": 0,
            "last_check": 0,
        }
        # asyncio.Task of the refresh currently in flight, or None.
        self.refresh_promise = None

    @classmethod
    def getInstance(cls):
        """Return the shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def getCredentialFilePath(self):
        """Return the path of the credential file.

        ``~/.qwen/oauth_creds.json`` is used when it exists as a regular
        file; otherwise the result of ``get_cache_file()`` is returned
        (defined outside this file, presumably by ``AuthFileMixin``).
        """
        path = Path(os.path.expanduser(f"~/{QWEN_DIR}/{QWEN_CREDENTIAL_FILENAME}"))
        if path.is_file():
            return path
        return SharedTokenManager.get_cache_file()

    def getLockFilePath(self):
        """Return the path of the lock file, ``~/.qwen/oauth_creds.lock``."""
        return Path(os.path.expanduser(f"~/{QWEN_DIR}/{QWEN_LOCK_FILENAME}"))

    def setLockConfig(self, config: dict):
        """Accept a lock configuration; currently a no-op."""
        pass

    def registerCleanupHandlers(self):
        """Register an ``atexit`` hook that removes the lock file.

        The hook deletes the lock file unconditionally, without checking
        which process created it.  Removal is best effort: a missing or
        undeletable file is ignored.
        """
        import atexit

        def cleanup():
            try:
                self.getLockFilePath().unlink()
            except OSError:
                # Nothing to clean up, or the file cannot be removed.
                pass

        atexit.register(cleanup)

    async def getValidCredentials(self, qwen_client: IQwenOAuth2Client, force_refresh: bool = False):
        """Return usable credentials, refreshing them when necessary.

        Args:
            qwen_client: OAuth client used to perform the refresh.
            force_refresh: Refresh even if the cached token is still valid.

        Returns:
            The credentials dict (cached or freshly refreshed).

        Raises:
            TokenManagerError: For every failure; errors that are not
                already ``TokenManagerError`` are wrapped as
                ``REFRESH_FAILED``.
        """
        try:
            self.checkAndReloadIfNeeded()

            cached = self.memory_cache["credentials"]
            if cached and not force_refresh and self.isTokenValid(cached):
                return cached

            # Join a refresh that is already running instead of starting another.
            pending = self.refresh_promise
            if pending is not None:
                return await pending

            task = asyncio.create_task(self.performTokenRefresh(qwen_client, force_refresh))
            self.refresh_promise = task
            try:
                return await task
            finally:
                # Always forget the finished task, including on failure;
                # otherwise later calls would keep awaiting the same failed
                # task.  Do not clobber a newer task set in the meantime.
                if self.refresh_promise is task:
                    self.refresh_promise = None
        except TokenManagerError:
            raise
        except Exception as e:
            raise TokenManagerError(TokenError.REFRESH_FAILED, str(e), e) from e

    def checkAndReloadIfNeeded(self):
        """Reload the cache when the credential file changed on disk.

        The check runs at most once per ``CACHE_CHECK_INTERVAL_MS``.  A
        missing file only resets the remembered modification time.

        Raises:
            TokenManagerError: ``FILE_ACCESS_ERROR`` when the file cannot be
                inspected, read, parsed or validated.
        """
        now = int(time.time() * 1000)
        if now - self.memory_cache["last_check"] < CACHE_CHECK_INTERVAL_MS:
            return
        self.memory_cache["last_check"] = now

        try:
            file_path = self.getCredentialFilePath()
            file_mod_time = int(file_path.stat().st_mtime * 1000)
            if file_mod_time > self.memory_cache["file_mod_time"]:
                self.reloadCredentialsFromFile()
                self.memory_cache["file_mod_time"] = file_mod_time
        except FileNotFoundError:
            self.memory_cache["file_mod_time"] = 0
        except TokenManagerError:
            # Already categorised (and the cache already cleared) by
            # reloadCredentialsFromFile(); do not wrap it a second time.
            raise
        except Exception as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, str(e), e) from e

    def reloadCredentialsFromFile(self):
        """Read, validate and cache the credentials stored on disk.

        On any failure the cached credentials are cleared.

        Raises:
            TokenManagerError: ``FILE_ACCESS_ERROR`` when the file is
                missing, is not valid JSON, or fails validation.
        """
        file_path = self.getCredentialFilePath()
        debug.log(f"Reloading credentials from {file_path}")
        try:
            with open(file_path, "r") as fs:
                data = json.load(fs)
            credentials = self.validateCredentials(data)
            self.memory_cache["credentials"] = credentials
        except FileNotFoundError as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, "Credentials file not found", e) from e
        except json.JSONDecodeError as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, "Invalid JSON format", e) from e
        except Exception as e:
            self.memory_cache["credentials"] = None
            raise TokenManagerError(TokenError.FILE_ACCESS_ERROR, str(e), e) from e

    def validateCredentials(self, data):
        """Check that *data* has the shape of a credentials dict.

        Requires string values for ``access_token``, ``refresh_token`` and
        ``token_type`` and a numeric ``expiry_date``.

        Returns:
            *data* itself, unchanged.

        Raises:
            ValueError: When *data* is empty, not a dict, or a field is
                missing or has the wrong type.
        """
        if not data or not isinstance(data, dict):
            raise ValueError("Invalid credentials format")
        for field in ["access_token", "refresh_token", "token_type"]:
            if field not in data or not isinstance(data[field], str):
                raise ValueError(f"Invalid credentials: missing {field}")
        if "expiry_date" not in data or not isinstance(data["expiry_date"], (int, float)):
            raise ValueError("Invalid credentials: missing expiry_date")
        return data

    async def performTokenRefresh(self, qwen_client: IQwenOAuth2Client, force_refresh: bool):
        """Refresh the access token while holding the lock file.

        Args:
            qwen_client: OAuth client that holds the credentials and
                performs the refresh request.
            force_refresh: Refresh even if the token found after acquiring
                the lock is still valid.

        Returns:
            The credentials dict now in effect.

        Raises:
            TokenManagerError: ``NO_REFRESH_TOKEN``, ``LOCK_TIMEOUT``,
                ``FILE_ACCESS_ERROR`` or ``REFRESH_FAILED``.  Other
                exceptions (for example a missing ``token_type`` in the
                response) propagate unchanged.
        """
        lock_path = self.getLockFilePath()

        if self.memory_cache["credentials"] is None:
            self.reloadCredentialsFromFile()
        qwen_client.setCredentials(self.memory_cache["credentials"])
        current_credentials = qwen_client.getCredentials()
        if not current_credentials.get("refresh_token"):
            raise TokenManagerError(TokenError.NO_REFRESH_TOKEN, "No refresh token")

        # The lock is released only if it was actually acquired, so a failure
        # above or a lock timeout never deletes a lock held by someone else.
        await self.acquireLock(lock_path)
        try:
            # Another holder of the lock may have refreshed the token while we
            # were waiting.  Reset the throttle so the file is really checked.
            self.memory_cache["last_check"] = 0
            self.checkAndReloadIfNeeded()

            cached = self.memory_cache["credentials"]
            if not force_refresh and cached and self.isTokenValid(cached):
                qwen_client.setCredentials(cached)
                return cached

            response = await qwen_client.refreshAccessToken()
            if not response or isErrorResponse(response):
                raise TokenManagerError(TokenError.REFRESH_FAILED, str(response))
            token_data = response
            if "access_token" not in token_data:
                raise TokenManagerError(TokenError.REFRESH_FAILED, "No access_token returned")

            credentials = {
                "access_token": token_data["access_token"],
                "token_type": token_data["token_type"],
                # Keep the previous refresh token when the server sends no new one.
                "refresh_token": token_data.get("refresh_token", current_credentials.get("refresh_token")),
                "resource_url": token_data.get("resource_url"),
                # expires_in is multiplied by 1000 and added to the current time in ms.
                "expiry_date": int(time.time() * 1000) + token_data.get("expires_in", 0) * 1000,
            }
            self.memory_cache["credentials"] = credentials
            qwen_client.setCredentials(credentials)

            await self.saveCredentialsToFile(credentials)
            return credentials
        finally:
            await self.releaseLock(lock_path)

    async def acquireLock(self, lock_path: Path):
        """Create the lock file, waiting while another holder owns it.

        Makes up to 50 attempts, 200 ms apart.  A lock file older than
        ``LOCK_TIMEOUT_MS`` is treated as stale and removed.

        Raises:
            TokenManagerError: ``LOCK_TIMEOUT`` when every attempt fails.
        """
        max_attempts = 50
        attempt_interval = 200  # milliseconds
        lock_id = str(uuid.uuid4())
        os.makedirs(lock_path.parent, exist_ok=True)

        for _ in range(max_attempts):
            try:
                # Exclusive create: fails if the lock file already exists.
                # Mode "w" would silently overwrite an existing lock.
                with open(lock_path, "x") as f:
                    f.write(lock_id)
                return
            except OSError:
                try:
                    lock_mtime_ms = int(os.stat(str(lock_path)).st_mtime * 1000)
                    lock_age = int(time.time() * 1000) - lock_mtime_ms
                    if lock_age > LOCK_TIMEOUT_MS:
                        os.unlink(str(lock_path))
                except OSError:
                    # The lock vanished or cannot be inspected; just retry.
                    pass
            await asyncio.sleep(attempt_interval / 1000)
        raise TokenManagerError(TokenError.LOCK_TIMEOUT, "Failed to acquire lock")

    async def releaseLock(self, lock_path: Path):
        """Remove the lock file; a missing or undeletable file is ignored."""
        try:
            # os.unlink is synchronous and returns None; it must not be awaited.
            os.unlink(str(lock_path))
        except OSError:
            pass

    async def saveCredentialsToFile(self, credentials: dict):
        """Write *credentials* to the credential file as indented JSON.

        The data is written to a temporary file in the same directory and
        then moved into place, so a concurrent reader never sees a partially
        written file.  The remembered modification time is updated afterwards.

        Raises:
            OSError: When the file cannot be written or replaced.
            TypeError: When *credentials* cannot be serialized to JSON.
        """
        file_path = Path(self.getCredentialFilePath())
        os.makedirs(file_path.parent, exist_ok=True)

        # Serialize first so a serialization error leaves no temp file behind.
        payload = json.dumps(credentials, indent=2)
        tmp_path = file_path.with_name(f"{file_path.name}.tmp.{uuid.uuid4().hex}")
        try:
            # os.replace installs a new file, so its permissions are set here:
            # owner read/write only, because the file holds OAuth tokens.
            fd = os.open(str(tmp_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(str(tmp_path), str(file_path))
        except OSError:
            try:
                os.unlink(str(tmp_path))
            except OSError:
                pass
            raise

        stat = os.stat(str(file_path))
        self.memory_cache["file_mod_time"] = int(stat.st_mtime * 1000)

    def isTokenValid(self, credentials: dict) -> bool:
        """Return True while the token is more than the refresh buffer from expiry.

        A missing or falsy ``expiry_date`` counts as invalid.
        """
        expiry_date = credentials.get("expiry_date")
        if not expiry_date:
            return False
        return time.time() * 1000 < expiry_date - TOKEN_REFRESH_BUFFER_MS

    def getCurrentCredentials(self):
        """Return the credentials currently cached in memory (may be None)."""
        return self.memory_cache["credentials"]
|
|
|