""" rate_limiter.py This module provides a RateLimiter class for managing API rate limits based on both request count and token usage. It's designed to work with OpenAI's API rate limit headers but can be adapted for other APIs with similar rate limiting mechanisms. The module includes: 1. A RateLimitInfo dataclass to store rate limit information. 2. A parse_time_string function to convert time strings to seconds using pytimeparse. 3. A RateLimiter class that manages and enforces rate limits. Dependencies: pytimeparse: Install with `pip install pytimeparse` Usage: rate_limiter = RateLimiter(initial_rate_requests=60, initial_rate_tokens=150000) await rate_limiter.acquire(estimated_tokens) # Make API call rate_limiter.update_rate_limit_info(response_headers) """ import asyncio import math import re import time from dataclasses import dataclass from httpx import Headers from pytimeparse import parse from vsp.shared import logger_factory logger = logger_factory.get_logger(__name__) @dataclass class OpenAIRateLimitInfo: """ Dataclass to store rate limit information for OpenAI API. Attributes: limit_requests (int): Maximum number of requests allowed. limit_tokens (int): Maximum number of tokens allowed. remaining_requests (int): Number of requests remaining. remaining_tokens (int): Number of tokens remaining. reset_requests_seconds (float): Time in seconds until the request limit resets. reset_tokens_seconds (float): Time in seconds until the token limit resets. """ limit_requests: int limit_tokens: int remaining_requests: int remaining_tokens: int reset_requests_seconds: float reset_tokens_seconds: float def parse_time_string(time_str: str) -> float: """ Parse a time string to seconds, handling various formats including milliseconds. This function handles time strings in various formats, including those returned in OpenAI's rate limit headers (e.g., '12h38m27.913s', '6m0s', '932ms'). Args: time_str (str): A string representing a duration. Returns: float: The total number of seconds represented by the time string. Raises: ValueError: If the time string cannot be parsed. Examples:: >>> parse_time_string('1h30m') 5400.0 >>> parse_time_string('45.5s') 45.5 >>> parse_time_string('1h23m45.6s') 5025.6 >>> parse_time_string('6m0s') 360.0 >>> parse_time_string('932ms') 0.932 """ # Check if the string is in milliseconds format ms_match = re.match(r"^(\d+)ms$", time_str) if ms_match: return float(ms_match.group(1)) / 1000 # Use pytimeparse for other formats seconds = parse(time_str) if seconds is None: raise ValueError(f"Could not parse time string: {time_str}") return float(seconds) class OpenAIRateLimiter: """ A class to manage and enforce rate limits for OpenAI API calls. This class handles both request-based and token-based rate limiting. It can adapt to changing rate limits based on information provided in API response headers. """ def __init__(self, initial_rate_requests: int, initial_rate_tokens: int, per: float = 60.0): """ Initialize the OpenAIRateLimiter. Args: initial_rate_requests (int): Initial number of requests allowed per minute. initial_rate_tokens (int): Initial number of tokens allowed per minute. per (float): The time period in seconds for which the rate applies. Defaults to 60.0. """ self.rate_requests = initial_rate_requests self.rate_tokens = initial_rate_tokens self.per = per self.allowance_requests = initial_rate_requests self.allowance_tokens = initial_rate_tokens self.last_check = time.time() self.rate_limit_info: OpenAIRateLimitInfo | None = None async def acquire(self, tokens_to_use: int) -> None: """ Acquire permission to make an API call, respecting rate limits. This method checks both request and token limits. If either limit is exceeded, it will pause execution for an appropriate amount of time. Args: tokens_to_use (int): The estimated number of tokens that will be used in this API call. Raises: asyncio.TimeoutError: If the sleep time exceeds the maximum allowed wait time. """ now = time.time() time_passed = now - self.last_check self.last_check = now if self.rate_limit_info: self.allowance_requests = self.rate_limit_info.remaining_requests self.allowance_tokens = self.rate_limit_info.remaining_tokens reset_time_requests_seconds = self.rate_limit_info.reset_requests_seconds reset_time_tokens_seconds = self.rate_limit_info.reset_tokens_seconds else: self.allowance_requests += int(time_passed * (self.rate_requests / self.per)) self.allowance_tokens += int(time_passed * (self.rate_tokens / self.per)) self.allowance_requests = min(self.allowance_requests, self.rate_requests) self.allowance_tokens = min(self.allowance_tokens, self.rate_tokens) # Default reset time if we don't have rate limit info reset_time_requests_seconds = reset_time_tokens_seconds = 30 wait_time = 0 if self.allowance_requests < 1: wait_time = int(max(wait_time, math.ceil(reset_time_requests_seconds))) if self.allowance_tokens < tokens_to_use: # If token reset time is more than a minute, it's likely the daily limit # In this case, we'll wait for the request reset time instead wait_time = int(max(wait_time, min(math.ceil(reset_time_tokens_seconds), 60))) if wait_time > 0: logger.info("Rate limit exceeded", sleep_time=wait_time) await asyncio.sleep(wait_time) self.allowance_requests = self.rate_requests self.allowance_tokens = self.rate_tokens else: self.allowance_requests -= 1 self.allowance_tokens -= tokens_to_use def update_rate_limit_info(self, headers: Headers) -> None: """ Update the rate limit information based on API response headers. This method should be called after each successful API call to keep the rate limit information up-to-date. Args: headers (Headers): The response headers from the API call. Note: This method expects headers to include 'x-ratelimit-*' keys as provided by the OpenAI API. If these headers are not present, it falls back to the initial or current values. """ self.rate_limit_info = OpenAIRateLimitInfo( limit_requests=int(headers.get("x-ratelimit-limit-requests", self.rate_requests)), limit_tokens=int(headers.get("x-ratelimit-limit-tokens", self.rate_tokens)), remaining_requests=int(headers.get("x-ratelimit-remaining-requests", self.allowance_requests)), remaining_tokens=int(headers.get("x-ratelimit-remaining-tokens", self.allowance_tokens)), reset_requests_seconds=parse_time_string(headers.get("x-ratelimit-reset-requests", "30s")), reset_tokens_seconds=parse_time_string(headers.get("x-ratelimit-reset-tokens", "30s")), )