Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
feat(rate limit aware acompletion calls):
This commit is contained in:
parent 035c65ed4a
commit 228d6ea608
1 changed file with 108 additions and 18 deletions
litellm/utils.py (126 changed lines: +108 −18)
@@ -3800,17 +3800,20 @@ class APIRequest:
         self,
         request_header: dict,
         retry_queue: asyncio.Queue,
-        save_filepath: str,
         status_tracker: StatusTracker,
+        save_filepath: str = "",
     ):
         """Calls the OpenAI API and saves results."""
-        logging.info(f"Making API Call for request #{self.task_id}")
+        logging.info(f"Making API Call for request #{self.task_id} {self.request_json}")
         error = None
         try:
             response = await litellm.acompletion(
                 **self.request_json
             )
             logging.info(f"Completed request #{self.task_id}")
+            if save_filepath == "":  # return the response directly
+                return response
+            # else the response gets written to save_filepath
         except Exception as e:
             logging.warning(
                 f"Request {self.task_id} failed with error {e}"
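With save_filepath now defaulting to an empty string, call_api doubles as an in-process call: when no path is given, the completion is returned to the caller instead of being written to disk. A minimal sketch of that path, assuming the APIRequest constructor fields shown in the next hunk (task_id, request_json, token_consumption, attempts_left, metadata); the model, message, and token values are placeholders:

import asyncio

async def main():
    request = APIRequest(
        task_id=0,
        request_json={
            "model": "gpt-3.5-turbo",  # placeholder model
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
        },
        token_consumption=10,  # placeholder estimate
        attempts_left=5,
        metadata=None,
    )
    response = await request.call_api(
        request_header={},
        retry_queue=asyncio.Queue(),
        save_filepath="",  # empty path: the response is returned, not saved
        status_tracker=StatusTracker(),
    )
    print(response)

asyncio.run(main())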
@@ -3861,31 +3864,118 @@ class APIRequest:
+
+
+class RateLimitManager():
+    import uuid
+    def __init__(self, max_tokens_per_minute, max_requests_per_minute):
+        self.max_tokens_per_minute = max_tokens_per_minute
+        self.max_requests_per_minute = max_requests_per_minute
+        # print("init rate limit handler")
+        self.status_tracker = StatusTracker()
+        self.last_update_time = time.time()
+        self.available_request_capacity = max_requests_per_minute
+        self.available_token_capacity = max_tokens_per_minute
+        self.queue_of_requests_to_retry = asyncio.Queue()  # type: ignore
+        self.task = 0  # for tracking ids for tasks
+        self.cooldown_time = 10  # time to cool down between retries, in seconds
+
+    # async def acompletion(self, max_attempts=5, kwargs):
+    #     # init request
+    #     request = APIRequest(
+    #         task_id=next(self.task_id_generator_function()),
+    #         request_json=kwargs,
+    #         token_consumption=self.num_tokens_consumed_from_request(
+    #             request_json, token_encoding_name
+    #         ),
+    #         attempts_left=max_attempts,
+    #         metadata=request_json.pop("metadata", None),
+    #     )
+    async def acompletion(self, max_attempts=5, **kwargs):
+        # Initialize logging
+        logging.basicConfig(level=logging.INFO)
+
+        # # check current capacity for model
+        # Initialize request
+        logging.info(f"Initializing API request for request id:{self.task}")
+        request = APIRequest(
+            task_id=self.task,
+            request_json=kwargs,
+            token_consumption=self.num_tokens_consumed_from_request(request_json=kwargs, token_encoding_name="cl100k_base"),
+            attempts_left=max_attempts,
+            metadata=kwargs.pop("metadata", None),
+        )
+        self.task += 1  # added a new task to execute
+
+        # # if under capacity
+        # # check if fallback model specified
+        # Check and update current capacity for model
+        current_time = time.time()
+        seconds_since_update = current_time - self.last_update_time
+
+        self.available_request_capacity = min(
+            self.available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0,
+            self.max_requests_per_minute,
+        )
+
+        # # if no fallback model specified then wait to process request
+
+        self.available_token_capacity = min(
+            self.available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0,
+            self.max_tokens_per_minute,
+        )
+
+        self.last_update_time = current_time
+
+        request_tokens = request.token_consumption
+        logging.debug("Request tokens: " + str(request_tokens))
+
+        queue_of_requests_to_retry = asyncio.Queue()
+
+        if self.available_request_capacity >= 1 and self.available_token_capacity >= request_tokens:
+            # Update counters
+            self.available_request_capacity -= 1
+            self.available_token_capacity -= request_tokens
+            request.attempts_left -= 1
+
+            # Call API and log final status
+            logging.info(f"Running Request {request.task_id}, using tokens: {request.token_consumption}. Remaining available tokens: {self.available_token_capacity}")
+
+            result = await request.call_api(
+                request_header={},
+                retry_queue=queue_of_requests_to_retry,
+                save_filepath="",
+                status_tracker=self.status_tracker,
+            )
+            return result
+        else:
+            logging.info(f"OVER CAPACITY for {request.task_id}. Retrying up to {request.attempts_left} more times.")
+            while request.attempts_left >= 0:
+                # Sleep to let capacity replenish
+                logging.info(f"OVER CAPACITY for {request.task_id}. Cooling down for {self.cooldown_time} seconds; {request.attempts_left} attempts left.")
+                await asyncio.sleep(self.cooldown_time)
+
+                # Check capacity
+                current_time = time.time()
+                seconds_since_update = current_time - self.last_update_time
+
+                self.available_request_capacity = min(
+                    self.available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0,
+                    self.max_requests_per_minute,
+                )
+
+                self.available_token_capacity = min(
+                    self.available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0,
+                    self.max_tokens_per_minute,
+                )
+
+                self.last_update_time = current_time
+
+                request_tokens = request.token_consumption
+
+                if self.available_request_capacity >= 1 and self.available_token_capacity >= request_tokens:
+                    logging.info("Token capacity available again.")
+
+                    # Update counters
+                    self.available_request_capacity -= 1
+                    self.available_token_capacity -= request_tokens
+                    request.attempts_left -= 1
+
+                    # Call API and log final status
+                    logging.info(f"Running Request {request.task_id}, using tokens: {request.token_consumption}. Remaining available tokens: {self.available_token_capacity}")
+
+                    result = await request.call_api(
+                        request_header={},
+                        retry_queue=queue_of_requests_to_retry,
+                        save_filepath="",
+                        status_tracker=self.status_tracker,
+                    )
+                    return result
+
+                logging.warning(f"Request {request.task_id} is still over capacity. Number of retry attempts left: {request.attempts_left}")
+                request.attempts_left -= 1
+
+    async def batch_completion(
+        self,
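A note on the token accounting used above: each request's cost is estimated up front by num_tokens_consumed_from_request with the cl100k_base encoding. That helper's body is not shown in this diff; the sketch below is a hedged approximation of what such a counter typically computes (the per-message overhead and the 15-token completion default are assumptions, not litellm's exact math):

import tiktoken

def estimate_request_tokens(request_json: dict, token_encoding_name: str = "cl100k_base") -> int:
    # Prompt tokens from the chat messages, plus a budget for the reply.
    encoding = tiktoken.get_encoding(token_encoding_name)
    num_tokens = 0
    for message in request_json.get("messages", []):
        num_tokens += 4  # assumed per-message framing overhead
        num_tokens += len(encoding.encode(str(message.get("content", ""))))
    num_tokens += request_json.get("max_tokens", 15)  # assumed completion budget
    return num_tokens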
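The two capacity updates in acompletion are a standard token-bucket refill: capacity grows linearly at the per-minute rate for the seconds elapsed since the last update, clamped at the bucket maximum. Isolated as a standalone helper (names here are illustrative, not from the diff):

import time

def refill(available: float, per_minute_limit: float, last_update_time: float) -> tuple[float, float]:
    # available += limit * elapsed_seconds / 60, capped at the limit
    now = time.time()
    elapsed = now - last_update_time
    available = min(available + per_minute_limit * elapsed / 60.0, per_minute_limit)
    return available, now

For example, with a 60,000 tokens-per-minute limit and 15 seconds elapsed, up to 60,000 * 15 / 60 = 15,000 tokens of capacity are restored.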
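End to end, the manager is constructed with per-minute budgets and awaited much like litellm.acompletion itself, since extra kwargs pass straight through to the underlying call. A sketch, assuming the class is imported from litellm/utils.py where this diff defines it (the import path, model, limits, and messages are placeholders):

import asyncio
from litellm.utils import RateLimitManager  # assumed import path

handler = RateLimitManager(max_tokens_per_minute=60_000, max_requests_per_minute=60)

async def main():
    response = await handler.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

asyncio.run(main())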
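The diff view is cut off at batch_completion's signature, so its body is not visible here. Purely as a hypothetical illustration of how a batch helper could sit on top of the rate-limit-aware acompletion (this is not the committed implementation):

import asyncio

async def batch_completion_sketch(handler, requests, max_attempts=5):
    # Fan out every request through the rate-limit-aware acompletion
    # and collect results (or exceptions) in submission order.
    tasks = [handler.acompletion(max_attempts=max_attempts, **req) for req in requests]
    return await asyncio.gather(*tasks, return_exceptions=True)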