vsp-demo / src /vsp /llm /openai /openai_rate_limiter.py
navkast
Update location of the VSP module (#1)
c1f8477 unverified
"""
openai_rate_limiter.py
This module provides an OpenAIRateLimiter class for managing API rate limits based on both request count and token
usage. It is designed to work with OpenAI's API rate limit headers but can be adapted for other APIs with similar
rate limiting mechanisms.
The module includes:
1. An OpenAIRateLimitInfo dataclass to store rate limit information.
2. A parse_time_string function to convert time strings to seconds using pytimeparse.
3. An OpenAIRateLimiter class that manages and enforces rate limits.
Dependencies:
pytimeparse: Install with `pip install pytimeparse`
Usage:
rate_limiter = OpenAIRateLimiter(initial_rate_requests=60, initial_rate_tokens=150000)
await rate_limiter.acquire(estimated_tokens)
# Make API call
rate_limiter.update_rate_limit_info(response_headers)
"""
import asyncio
import math
import re
import time
from dataclasses import dataclass
from httpx import Headers
from pytimeparse import parse
from vsp.shared import logger_factory
logger = logger_factory.get_logger(__name__)
@dataclass
class OpenAIRateLimitInfo:
    """
    Rate limit figures for the OpenAI API, one field per tracked value.

    Requests and tokens are tracked side by side: each has a limit, a remaining
    count, and a number of seconds until its limit resets.

    Attributes:
        limit_requests (int): Maximum number of requests allowed.
        limit_tokens (int): Maximum number of tokens allowed.
        remaining_requests (int): Number of requests remaining.
        remaining_tokens (int): Number of tokens remaining.
        reset_requests_seconds (float): Seconds until the request limit resets.
        reset_tokens_seconds (float): Seconds until the token limit resets.
    """

    # Limits (maximum allowed).
    limit_requests: int
    limit_tokens: int
    # Budget still available.
    remaining_requests: int
    remaining_tokens: int
    # Time until each limit resets, in seconds.
    reset_requests_seconds: float
    reset_tokens_seconds: float
# Sub-second unit suffixes handled locally, mapped to how many of that unit make up one second.
_SUBSECOND_UNITS_PER_SECOND = {
    "ms": 1_000,
    "us": 1_000_000,
    "\u00b5s": 1_000_000,  # micro sign
    "\u03bcs": 1_000_000,  # Greek small letter mu
    "ns": 1_000_000_000,
}

# A single integer or decimal number immediately followed by one of the suffixes above.
_SUBSECOND_PATTERN = re.compile(r"(\d+(?:\.\d+)?)(ms|us|\u00b5s|\u03bcs|ns)")


def parse_time_string(time_str: str) -> float:
    """
    Parse a time string to seconds, handling various formats including sub-second units.

    A lone value with a sub-second suffix (``ms``, ``us``/``µs``, ``ns``), integer or
    fractional, is converted directly. Every other format, such as the compound durations
    found in OpenAI's rate limit headers (e.g., '12h38m27.913s', '6m0s'), is delegated to
    pytimeparse. Surrounding whitespace is ignored.

    Args:
        time_str (str): A string representing a duration.

    Returns:
        float: The total number of seconds represented by the time string.

    Raises:
        ValueError: If the time string cannot be parsed.

    Examples::

        >>> parse_time_string('1h30m')
        5400.0
        >>> parse_time_string('45.5s')
        45.5
        >>> parse_time_string('1h23m45.6s')
        5025.6
        >>> parse_time_string('6m0s')
        360.0
        >>> parse_time_string('932ms')
        0.932
        >>> parse_time_string('1.5ms')
        0.0015
    """
    text = time_str.strip()

    # Sub-second values are converted here rather than being passed to pytimeparse.
    subsecond = _SUBSECOND_PATTERN.fullmatch(text)
    if subsecond:
        value, unit = subsecond.groups()
        # Divide (rather than multiply by a fraction) so e.g. 932ms yields exactly 0.932.
        return float(value) / _SUBSECOND_UNITS_PER_SECOND[unit]

    # Use pytimeparse for other formats
    seconds = parse(text)
    if seconds is None:
        raise ValueError(f"Could not parse time string: {time_str}")
    return float(seconds)
class OpenAIRateLimiter:
    """
    A class to manage and enforce rate limits for OpenAI API calls.

    Two budgets are tracked side by side: a request budget and a token budget.

    * Before any response headers have been recorded, both budgets are refilled locally and
      continuously at ``rate / per`` units per second, capped at the configured rate.
    * After ``update_rate_limit_info`` has been called, the remaining counts and reset times
      taken from the headers are used instead. Each new header snapshot replaces the local
      budgets once; between snapshots the budgets are decremented locally by every ``acquire``.
    """

    # Seconds to wait when a budget is exhausted and no reset time has been reported yet.
    _DEFAULT_RESET_SECONDS = 30
    # Upper bound, in seconds, on a single wait caused by the token budget.
    _MAX_TOKEN_WAIT_SECONDS = 60

    def __init__(self, initial_rate_requests: int, initial_rate_tokens: int, per: float = 60.0):
        """
        Initialize the OpenAIRateLimiter.

        Args:
            initial_rate_requests (int): Initial number of requests allowed per ``per`` seconds.
            initial_rate_tokens (int): Initial number of tokens allowed per ``per`` seconds.
            per (float): The time period in seconds for which the rates apply. Defaults to 60.0.
        """
        self.rate_requests = initial_rate_requests
        self.rate_tokens = initial_rate_tokens
        self.per = per
        # Budgets are kept as floats so partial refills accumulate across calls instead of being truncated away.
        self.allowance_requests: float = initial_rate_requests
        self.allowance_tokens: float = initial_rate_tokens
        self.last_check = time.time()
        self.rate_limit_info: OpenAIRateLimitInfo | None = None
        # (snapshot, wall-clock time it was recorded); lets acquire() age the reported reset times.
        self._info_stamp: tuple[OpenAIRateLimitInfo, float] | None = None
        # The snapshot whose remaining counts were last copied into the local budgets.
        self._synced_info: OpenAIRateLimitInfo | None = None

    def _info_age(self, info: "OpenAIRateLimitInfo", now: float) -> float:
        """Return the seconds elapsed since ``info`` was recorded, or 0.0 when that time is unknown."""
        stamp = self._info_stamp
        if stamp is None or stamp[0] is not info:
            # rate_limit_info was assigned without going through update_rate_limit_info: treat it as fresh.
            return 0.0
        return max(0.0, now - stamp[1])

    async def acquire(self, tokens_to_use: int) -> None:
        """
        Acquire permission to make an API call, respecting rate limits.

        Both the request budget and the token budget are checked. If either is insufficient,
        execution is paused (via ``asyncio.sleep``) for the time remaining until the relevant
        limit resets; a wait caused by the token budget is capped at 60 seconds. After a wait
        the budgets are assumed to be full again. The call is always charged against both
        budgets before returning, whether or not it had to wait.

        When the reset time reported by the API has already elapsed, the call is let through
        without waiting; the next header update re-synchronizes the budgets.

        Args:
            tokens_to_use (int): The estimated number of tokens that will be used in this API call.
        """
        now = time.time()
        # Clamp so a wall-clock step backwards cannot drain the budgets.
        time_passed = max(0.0, now - self.last_check)
        self.last_check = now

        info = self.rate_limit_info
        if info is not None:
            if info is not self._synced_info:
                # Adopt each new snapshot exactly once. Re-reading it on every call would discard the
                # decrements made since, letting concurrent callers all spend the same remaining budget.
                self.allowance_requests = info.remaining_requests
                self.allowance_tokens = info.remaining_tokens
                self._synced_info = info
            # The reported reset times count from when the headers arrived, not from now.
            age = self._info_age(info, now)
            reset_time_requests_seconds = max(0.0, info.reset_requests_seconds - age)
            reset_time_tokens_seconds = max(0.0, info.reset_tokens_seconds - age)
            full_requests, full_tokens = info.limit_requests, info.limit_tokens
        else:
            # Continuous refill: no int() truncation, so frequent calls still make progress.
            self.allowance_requests = min(
                self.allowance_requests + time_passed * (self.rate_requests / self.per),
                self.rate_requests,
            )
            self.allowance_tokens = min(
                self.allowance_tokens + time_passed * (self.rate_tokens / self.per),
                self.rate_tokens,
            )
            # Default reset time if we don't have rate limit info
            reset_time_requests_seconds = reset_time_tokens_seconds = self._DEFAULT_RESET_SECONDS
            full_requests, full_tokens = self.rate_requests, self.rate_tokens

        wait_time = 0
        if self.allowance_requests < 1:
            wait_time = max(wait_time, math.ceil(reset_time_requests_seconds))
        if self.allowance_tokens < tokens_to_use:
            # A token reset time above a minute is likely a longer-term (e.g. daily) limit;
            # cap the wait instead of sleeping for the whole period.
            wait_time = max(wait_time, min(math.ceil(reset_time_tokens_seconds), self._MAX_TOKEN_WAIT_SECONDS))

        if wait_time > 0:
            logger.info("Rate limit exceeded", sleep_time=wait_time)
            await asyncio.sleep(wait_time)
            self.allowance_requests = full_requests
            self.allowance_tokens = full_tokens
            # Do not count the time spent sleeping as refill time on the next call.
            self.last_check = time.time()

        # Charge this call, including one that had to wait first.
        self.allowance_requests -= 1
        self.allowance_tokens -= tokens_to_use

    def update_rate_limit_info(self, headers: "Headers") -> None:
        """
        Update the rate limit information based on API response headers.

        This method should be called after each successful API call to keep
        the rate limit information up-to-date. The time of the call is recorded
        so that ``acquire`` can account for how old the reset times are.

        Args:
            headers (Headers): The response headers from the API call.

        Raises:
            ValueError: If a present header value cannot be converted to a number or duration.

        Note:
            This method expects headers to include 'x-ratelimit-*' keys as provided by the OpenAI API.
            Missing limit/remaining headers fall back to the limiter's current values; missing
            reset headers fall back to "30s".
        """
        info = OpenAIRateLimitInfo(
            limit_requests=int(headers.get("x-ratelimit-limit-requests", self.rate_requests)),
            limit_tokens=int(headers.get("x-ratelimit-limit-tokens", self.rate_tokens)),
            remaining_requests=int(headers.get("x-ratelimit-remaining-requests", self.allowance_requests)),
            remaining_tokens=int(headers.get("x-ratelimit-remaining-tokens", self.allowance_tokens)),
            reset_requests_seconds=parse_time_string(headers.get("x-ratelimit-reset-requests", "30s")),
            reset_tokens_seconds=parse_time_string(headers.get("x-ratelimit-reset-tokens", "30s")),
        )
        self.rate_limit_info = info
        self._info_stamp = (info, time.time())