feat(dynamic_rate_limiter.py): passing base case

Author: Krrish Dholakia
Date:   2024-06-21 22:46:46 -07:00
Parent: a028600932
Commit: 068e8dff5b

5 changed files with 310 additions and 12 deletions


@@ -428,7 +428,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
-    mock_response: Union[str, Exception] = "This is a mock request",
+    mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
     logging=None,
     custom_llm_provider=None,
@@ -477,6 +477,9 @@ def mock_completion(
         if time_delay is not None:
             time.sleep(time_delay)
 
+        if isinstance(mock_response, dict):
+            return ModelResponse(**mock_response)
+
         model_response = ModelResponse(stream=stream)
         if stream is True:
            # don't try to access stream object,
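The change above lets callers hand mock_completion a fully formed response payload instead of a plain string. Below is a minimal sketch of how that might be exercised, assuming litellm.completion forwards a dict mock_response to mock_completion as in this diff; the payload fields mirror the test fixture added later in this commit.

    import litellm

    payload = {
        "id": "chatcmpl-abc123",
        "object": "chat.completion",
        "created": 1699896916,
        "model": "gpt-3.5-turbo-0125",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "mocked"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
    }

    # With a dict, mock_completion now returns ModelResponse(**payload), so the
    # mocked usage block survives into the response instead of being regenerated.
    resp = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hey"}],
        mock_response=payload,
    )
    assert resp["usage"]["total_tokens"] == 10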


@@ -80,25 +80,35 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         Returns
         - Tuple[available_tpm, model_tpm, active_projects]
           - available_tpm: int or null
-          - model_tpm: int or null. If available tpm is int, then this will be too.
+          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
           - active_projects: int or null
         """
         active_projects = await self.internal_usage_cache.async_get_cache(model=model)
+        current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
+            model_group=model
+        )
         model_group_info: Optional[ModelGroupInfo] = (
             self.llm_router.get_model_group_info(model_group=model)
         )
+        total_model_tpm: Optional[int] = None
+        if model_group_info is not None and model_group_info.tpm is not None:
+            total_model_tpm = model_group_info.tpm
+
+        remaining_model_tpm: Optional[int] = None
+        if total_model_tpm is not None and current_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm - current_model_tpm
+        elif total_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm
+
         available_tpm: Optional[int] = None
-        model_tpm: Optional[int] = None
-        if model_group_info is not None and model_group_info.tpm is not None:
-            model_tpm = model_group_info.tpm
+
+        if remaining_model_tpm is not None:
             if active_projects is not None:
-                available_tpm = int(model_group_info.tpm / active_projects)
+                available_tpm = int(remaining_model_tpm / active_projects)
             else:
-                available_tpm = model_group_info.tpm
+                available_tpm = remaining_model_tpm
 
-        return available_tpm, model_tpm, active_projects
+        return available_tpm, remaining_model_tpm, active_projects
 
     async def async_pre_call_hook(
         self,
@@ -121,6 +131,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         - Check if tpm available
         - Raise RateLimitError if no tpm available
         """
+
         if "model" in data:
             available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                 model=data["model"]
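The arithmetic in check_available_tpm is easy to sanity-check in isolation. Here is a minimal standalone sketch of that math; the helper below is purely illustrative and is not part of the handler's API.

    from typing import Optional

    def available_tpm(
        total_model_tpm: Optional[int],
        current_model_tpm: Optional[int],
        active_projects: Optional[int],
    ) -> Optional[int]:
        # No configured TPM limit for the model group -> nothing to ration.
        if total_model_tpm is None:
            return None
        # Subtract whatever the group has already used this minute.
        remaining = total_model_tpm - (current_model_tpm or 0)
        # Remaining quota is split evenly across the active projects.
        return int(remaining / active_projects) if active_projects else remaining

    # e.g. 50 TPM configured, 10 tokens used this minute, 2 active projects
    assert available_tpm(50, 10, 2) == 20
    assert available_tpm(50, 10, None) == 40   # no active-project count tracked yet
    assert available_tpm(None, 10, 2) is None  # model group has no tpm limit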


@@ -11,6 +11,7 @@ import asyncio
 import concurrent
 import copy
 import datetime as datetime_og
+import enum
 import hashlib
 import inspect
 import json
@@ -90,6 +91,10 @@ from litellm.utils import (
 )
 
 
+class RoutingArgs(enum.Enum):
+    ttl = 60  # 1min (RPM/TPM expire key)
+
+
 class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False
@@ -387,6 +392,11 @@ class Router:
                 routing_strategy=routing_strategy,
                 routing_strategy_args=routing_strategy_args,
             )
+        ## USAGE TRACKING ##
+        if isinstance(litellm._async_success_callback, list):
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
+        else:
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -2636,13 +2646,69 @@ class Router:
                     time.sleep(_timeout)
 
            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                setattr(original_exception, "max_retries", num_retries)
-                setattr(original_exception, "num_retries", current_attempt)
+                setattr(original_exception, "max_retries", num_retries)
+                setattr(original_exception, "num_retries", current_attempt)
 
            raise original_exception
 
     ### HELPER FUNCTIONS
 
+    async def deployment_callback_on_success(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        """
+        Track remaining tpm/rpm quota for model in model_list
+        """
+        try:
+            """
+            Update TPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = completion_response["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                dt = get_utc_datetime()
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"global_router:{id}:tpm:{current_minute}"
+
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                )
+        except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            pass
+
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
@@ -3963,6 +4029,35 @@ class Router:
 
         return model_group_info
 
+    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+        """
+        Returns the TPM usage recorded so far this minute for a model group
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
+        tpm_keys: List[str] = []
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                tpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
+                )
+
+        ## TPM
+        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
+            keys=tpm_keys
+        )
+        tpm_usage: Optional[int] = None
+        if tpm_usage_list is not None:
+            for t in tpm_usage_list:
+                if isinstance(t, int):
+                    if tpm_usage is None:
+                        tpm_usage = 0
+                    tpm_usage += t
+
+        return tpm_usage
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
@@ -4890,7 +4985,7 @@ class Router:
     def reset(self):
         ## clean up on close
         litellm.success_callback = []
-        litellm.__async_success_callback = []
+        litellm._async_success_callback = []
         litellm.failure_callback = []
         litellm._async_failure_callback = []
         self.retry_policy = None
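The router changes above hinge on one per-minute cache key per deployment. Here is a minimal in-memory sketch of that scheme, using a plain dict as a hypothetical stand-in for self.cache (the real code goes through async_increment_cache / async_batch_get_cache and lets keys expire via RoutingArgs.ttl).

    from datetime import datetime, timezone

    # Freeze the minute so the sketch is deterministic.
    CURRENT_MINUTE = datetime.now(timezone.utc).strftime("%H-%M")
    usage_cache: dict = {}  # hypothetical stand-in for the router's cache

    def tpm_key(deployment_id: str) -> str:
        # Same shape as the router's key: global_router:<deployment id>:tpm:<HH-MM>
        return f"global_router:{deployment_id}:tpm:{CURRENT_MINUTE}"

    def record_success(deployment_id: str, total_tokens: int) -> None:
        # deployment_callback_on_success increments this key with a 60s TTL,
        # so recorded usage naturally drops off once the minute rolls over.
        key = tpm_key(deployment_id)
        usage_cache[key] = usage_cache.get(key, 0) + total_tokens

    def model_group_usage(deployment_ids: list) -> int:
        # get_model_group_usage batch-reads the keys for every deployment in
        # the group and sums whatever integers it finds.
        return sum(usage_cache.get(tpm_key(d), 0) for d in deployment_ids)

    record_success("deployment-1", 10)
    record_success("deployment-1", 10)
    record_success("deployment-2", 5)
    assert model_group_usage(["deployment-1", "deployment-2"]) == 25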


@@ -8,7 +8,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime
-from typing import Tuple
+from typing import Optional, Tuple
 
 from dotenv import load_dotenv
 
@@ -40,6 +40,40 @@ def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
     return DynamicRateLimitHandler(internal_usage_cache=internal_cache)
 
 
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
 @pytest.mark.parametrize("num_projects", [1, 2, 100])
 @pytest.mark.asyncio
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
@@ -76,3 +110,65 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     expected_availability = int(model_tpm / num_projects)
 
     assert availability == expected_availability
+
+
+@pytest.mark.asyncio
+async def test_base_case(dynamic_rate_limit_handler, mock_response):
+    """
+    If just 1 active project
+    it should get all the quota
+
+    = allow request to go through
+    - update token usage
+    - exhaust all tpm with just 1 project
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+    for _ in range(5):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert availability == prev_availability - 10
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
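For the base-case test above: one active project gets the full 50 TPM and every mocked call reports 10 total tokens, so the observed availability should step down by 10 on each iteration. A quick check of that expectation:

    # 50 TPM, 10 tokens per mocked call, 1 active project
    model_tpm, tokens_per_call, active_projects = 50, 10, 1
    expected = [(model_tpm - i * tokens_per_call) // active_projects for i in range(5)]
    assert expected == [50, 40, 30, 20, 10]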


@@ -1730,3 +1730,96 @@ async def test_router_text_completion_client():
         print(responses)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
+@pytest.mark.asyncio
+async def test_router_model_usage(mock_response):
+    model = "my-fake-model"
+    model_tpm = 100
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    print(f"mock_response: {mock_response}")
+    model_tpm = 100
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+
+    allowed_fails = 1  # allow for changing b/w minutes
+
+    for _ in range(2):
+        try:
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+            await asyncio.sleep(3)
+
+            initial_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            # completion call - 10 tokens
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+            await asyncio.sleep(3)
+
+            updated_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            assert updated_usage == initial_usage + 10  # type: ignore
+            break
+        except Exception as e:
+            if allowed_fails > 0:
+                print(
+                    f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
+                )
+                allowed_fails -= 1
+            else:
+                print(f"allowed_fails: {allowed_fails}")
+                raise e