feat(rate limit aware acompletion calls):

This commit is contained in:
ishaan-jaff 2023-10-06 20:40:12 -07:00
parent 035c65ed4a
commit 228d6ea608

View file

@ -3800,17 +3800,20 @@ class APIRequest:
self,
request_header: dict,
retry_queue: asyncio.Queue,
save_filepath: str,
status_tracker: StatusTracker,
save_filepath: str = "",
):
"""Calls the OpenAI API and saves results."""
logging.info(f"Making API Call for request #{self.task_id}")
logging.info(f"Making API Call for request #{self.task_id} {self.request_json}")
error = None
try:
response = await litellm.acompletion(
**self.request_json
)
logging.info(f"Completed request #{self.task_id}")
if save_filepath == "": # return respons
return response
# else this gets written to save_filepath
except Exception as e:
logging.warning(
f"Request {self.task_id} failed with error {e}"
@ -3861,31 +3864,118 @@ class APIRequest:
class RateLimitManager():
import uuid
def __init__(self, max_tokens_per_minute, max_requests_per_minute):
self.max_tokens_per_minute = max_tokens_per_minute
self.max_requests_per_minute = max_requests_per_minute
# print("init rate limit handler")
self.status_tracker = StatusTracker()
self.last_update_time = time.time()
self.available_request_capacity = max_requests_per_minute
self.available_token_capacity = max_tokens_per_minute
self.queue_of_requests_to_retry = asyncio.Queue() # type: ignore
self.task = 0 # for tracking ids for tasks
self.cooldown_time = 10 # time to cooldown between retries in seconds
# async def acompletion(self, max_attempts=5, kwargs):
# # init request
# request = APIRequest(
# task_id=next(self.task_id_generator_function()),
# request_json=kwargs,
# token_consumption=self.num_tokens_consumed_from_request(
# request_json, token_encoding_name
# ),
# attempts_left=max_attempts,
# metadata=request_json.pop("metadata", None),
# )
async def acompletion(self, max_attempts=5, **kwargs):
# Initialize logging
logging.basicConfig(level=logging.INFO)
# # check current capacity for model
# Initialize request
logging.info(f"Initializing API request for request id:{self.task}")
request = APIRequest(
task_id=self.task,
request_json=kwargs,
token_consumption=self.num_tokens_consumed_from_request(request_json=kwargs, token_encoding_name="cl100k_base"),
attempts_left=max_attempts,
metadata=kwargs.pop("metadata", None),
)
self.task+=1 # added a new task to execute
# # if under capacity
# # check if fallback model specified
# Check and update current capacity for model
current_time = time.time()
seconds_since_update = current_time - self.last_update_time
self.available_request_capacity = min(
self.available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0,
self.max_requests_per_minute,
)
# # if no fallback model specified then wait to process request
self.available_token_capacity = min(
self.available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0,
self.max_tokens_per_minute,
)
self.last_update_time = current_time
request_tokens = request.token_consumption
logging.debug("Request tokens: " + str(request_tokens))
queue_of_requests_to_retry = asyncio.Queue()
if (self.available_request_capacity >= 1 and self.available_token_capacity >= request_tokens):
# Update counters
self.available_request_capacity -= 1
self.available_token_capacity -= request_tokens
request.attempts_left -= 1
# Call API and log final status
logging.info(f"""Running Request {request.task_id}, using tokens: {request.token_consumption}. Remaining available tokens: {self.available_token_capacity}""")
result = await request.call_api(
request_header={},
retry_queue=queue_of_requests_to_retry,
save_filepath="",
status_tracker=self.status_tracker,
)
return result
else:
logging.info(f"OVER CAPACITY for {request.task_id}. retrying {request.attempts_left} times")
while request.attempts_left >= 0:
# Sleep for a minute to allow for capacity
logging.info(f"OVER CAPACITY for {request.task_id}. Cooling down for 60 seconds, retrying {request.attempts_left} times")
await asyncio.sleep(self.cooldown_time)
# Check capacity
current_time = time.time()
seconds_since_update = current_time - self.last_update_time
self.available_request_capacity = min(
self.available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0,
self.max_requests_per_minute,
)
self.available_token_capacity = min(
self.available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0,
self.max_tokens_per_minute,
)
self.last_update_time = current_time
request_tokens = request.token_consumption
if self.available_request_capacity >= 1 and self.available_token_capacity >= request_tokens:
logging.info("Available token capacity available.")
# Update counters
self.available_request_capacity -= 1
self.available_token_capacity -= request_tokens
request.attempts_left -= 1
# Call API and log final status
logging.info(f"""Running Request {request.task_id}, using tokens: {request.token_consumption}. Remaining available tokens: {self.available_token_capacity}""")
result = await request.call_api(
request_header={},
retry_queue=queue_of_requests_to_retry,
save_filepath="",
status_tracker=self.status_tracker,
)
return result
logging.warning(f"Request {request.task_id} is still over capacity. Number of retry attempts left: {request.attempts_left}")
request.attempts_left -=1
async def batch_completion(
self,