Mirror of https://github.com/BerriAI/litellm.git
feat(dynamic_rate_limiter.py): passing base case
parent a028600932
commit 068e8dff5b

5 changed files with 310 additions and 12 deletions
@@ -428,7 +428,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
-    mock_response: Union[str, Exception] = "This is a mock request",
+    mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
     logging=None,
     custom_llm_provider=None,
@@ -477,6 +477,9 @@ def mock_completion(
         if time_delay is not None:
             time.sleep(time_delay)
 
+        if isinstance(mock_response, dict):
+            return ModelResponse(**mock_response)
+
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
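For reference, a minimal usage sketch of the dict form of mock_response added above: the hunk hydrates a plain dict via ModelResponse(**mock_response). The payload shape is illustrative, not prescribed by the commit, and the sketch assumes mock_completion is importable as litellm.mock_completion.

import litellm

# Hypothetical payload: any dict accepted by ModelResponse(**...) should work here.
mock_payload = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": 1699896916,
    "model": "gpt-3.5-turbo-0125",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "mocked!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
}

response = litellm.mock_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hey"}],
    mock_response=mock_payload,  # dict -> ModelResponse(**mock_payload)
)
print(response.usage)  # usage carried over from the mock payload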
@@ -80,25 +80,35 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         Returns
         - Tuple[available_tpm, model_tpm, active_projects]
           - available_tpm: int or null
-          - model_tpm: int or null. If available tpm is int, then this will be too.
+          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
           - active_projects: int or null
         """
         active_projects = await self.internal_usage_cache.async_get_cache(model=model)
+        current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
+            model_group=model
+        )
         model_group_info: Optional[ModelGroupInfo] = (
             self.llm_router.get_model_group_info(model_group=model)
         )
+        total_model_tpm: Optional[int] = None
+        if model_group_info is not None and model_group_info.tpm is not None:
+            total_model_tpm = model_group_info.tpm
+
+        remaining_model_tpm: Optional[int] = None
+        if total_model_tpm is not None and current_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm - current_model_tpm
+        elif total_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm
 
         available_tpm: Optional[int] = None
-        model_tpm: Optional[int] = None
 
-        if model_group_info is not None and model_group_info.tpm is not None:
-            model_tpm = model_group_info.tpm
+        if remaining_model_tpm is not None:
             if active_projects is not None:
-                available_tpm = int(model_group_info.tpm / active_projects)
+                available_tpm = int(remaining_model_tpm / active_projects)
             else:
-                available_tpm = model_group_info.tpm
+                available_tpm = remaining_model_tpm
 
-        return available_tpm, model_tpm, active_projects
+        return available_tpm, remaining_model_tpm, active_projects
 
     async def async_pre_call_hook(
         self,
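Stripped of the router plumbing, the availability math this hunk introduces reduces to a few lines. A standalone sketch follows; the function name and the zero-project guard are illustrative assumptions, not from the source.

from typing import Optional, Tuple

def split_remaining_tpm(
    total_model_tpm: Optional[int],
    current_model_tpm: Optional[int],
    active_projects: Optional[int],
) -> Tuple[Optional[int], Optional[int]]:
    # remaining TPM = configured model TPM minus usage recorded this minute
    remaining: Optional[int] = None
    if total_model_tpm is not None and current_model_tpm is not None:
        remaining = total_model_tpm - current_model_tpm
    elif total_model_tpm is not None:
        remaining = total_model_tpm

    # split what is left evenly across active projects; grant it whole otherwise
    available: Optional[int] = None
    if remaining is not None:
        available = int(remaining / active_projects) if active_projects else remaining
    return available, remaining

# e.g. a 100 TPM model, 10 tokens already used this minute, 3 active projects:
assert split_remaining_tpm(100, 10, 3) == (30, 90)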
@@ -121,6 +131,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         - Check if tpm available
         - Raise RateLimitError if no tpm available
         """
+
         if "model" in data:
             available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                 model=data["model"]
@@ -11,6 +11,7 @@ import asyncio
 import concurrent
 import copy
 import datetime as datetime_og
+import enum
 import hashlib
 import inspect
 import json
@@ -90,6 +91,10 @@ from litellm.utils import (
 )
 
 
+class RoutingArgs(enum.Enum):
+    ttl = 60  # 1min (RPM/TPM expire key)
+
+
 class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False
@@ -387,6 +392,11 @@ class Router:
             routing_strategy=routing_strategy,
             routing_strategy_args=routing_strategy_args,
         )
+        ## USAGE TRACKING ##
+        if isinstance(litellm._async_success_callback, list):
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
+        else:
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -2636,13 +2646,69 @@ class Router:
                     time.sleep(_timeout)
 
         if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-            original_exception.max_retries = num_retries
-            original_exception.num_retries = current_attempt
+            setattr(original_exception, "max_retries", num_retries)
+            setattr(original_exception, "num_retries", current_attempt)
 
         raise original_exception
 
     ### HELPER FUNCTIONS
 
+    async def deployment_callback_on_success(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        """
+        Track remaining tpm/rpm quota for model in model_list
+        """
+        try:
+            """
+            Update TPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = completion_response["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                dt = get_utc_datetime()
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"global_router:{id}:tpm:{current_minute}"
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                )
+
+        except Exception as e:
+            verbose_router_logger.error(
+                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            pass
+
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
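The success callback above keys usage on the deployment id and the current UTC minute (global_router:{id}:tpm:{HH-MM}) with a 60-second TTL. A toy sketch of that write path follows; DictCache and record_success are stand-ins invented for illustration, not litellm APIs.

import asyncio
from datetime import datetime, timezone

class DictCache:
    """Hypothetical in-memory stand-in for the router's async cache."""

    def __init__(self):
        self.store: dict = {}

    async def async_increment_cache(self, key: str, value: int, ttl: int) -> int:
        # ttl accepted for signature parity, ignored in this toy version
        self.store[key] = self.store.get(key, 0) + value
        return self.store[key]

async def record_success(cache: DictCache, deployment_id: str, total_tokens: int) -> str:
    # same key scheme as the hunk above: one counter per deployment, per minute
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    tpm_key = f"global_router:{deployment_id}:tpm:{current_minute}"
    await cache.async_increment_cache(key=tpm_key, value=total_tokens, ttl=60)
    return tpm_key

cache = DictCache()
key = asyncio.run(record_success(cache, deployment_id="abc123", total_tokens=10))
print(key, "->", cache.store[key])  # e.g. global_router:abc123:tpm:14-05 -> 10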
@@ -3963,6 +4029,35 @@ class Router:
 
         return model_group_info
 
+    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+        """
+        Returns remaining tpm quota for model group
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
+        tpm_keys: List[str] = []
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                tpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
+                )
+
+        ## TPM
+        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
+            keys=tpm_keys
+        )
+        tpm_usage: Optional[int] = None
+        if tpm_usage_list is not None:
+            for t in tpm_usage_list:
+                if isinstance(t, int):
+                    if tpm_usage is None:
+                        tpm_usage = 0
+                    tpm_usage += t
+
+        return tpm_usage
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
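get_model_group_usage is the read side of the same counters: it builds one key per deployment in the group, batch-reads them, and sums whatever integers come back (missing keys stay None). A small sketch of just that summing rule; sum_group_tpm is an illustrative name, not a litellm function.

from typing import List, Optional

def sum_group_tpm(tpm_usage_list: Optional[List]) -> Optional[int]:
    # mirror of the loop above: skip non-int entries, return None if nothing was tracked
    tpm_usage: Optional[int] = None
    if tpm_usage_list is not None:
        for t in tpm_usage_list:
            if isinstance(t, int):
                tpm_usage = (tpm_usage or 0) + t
    return tpm_usage

# two deployments reported usage this minute, one key was missing:
assert sum_group_tpm([40, None, 25]) == 65
# no counters at all -> None, which check_available_tpm treats as "no usage yet"
assert sum_group_tpm(None) is None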
@@ -4890,7 +4985,7 @@ class Router:
     def reset(self):
         ## clean up on close
         litellm.success_callback = []
-        litellm.__async_success_callback = []
+        litellm._async_success_callback = []
         litellm.failure_callback = []
         litellm._async_failure_callback = []
         self.retry_policy = None
@@ -8,7 +8,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime
-from typing import Tuple
+from typing import Optional, Tuple
 
 from dotenv import load_dotenv
 
@@ -40,6 +40,40 @@ def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
     return DynamicRateLimitHandler(internal_usage_cache=internal_cache)
 
 
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
 @pytest.mark.parametrize("num_projects", [1, 2, 100])
 @pytest.mark.asyncio
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
@@ -76,3 +110,65 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     expected_availability = int(model_tpm / num_projects)
 
     assert availability == expected_availability
+
+
+@pytest.mark.asyncio
+async def test_base_case(dynamic_rate_limit_handler, mock_response):
+    """
+    If just 1 active project
+
+    it should get all the quota
+
+    = allow request to go through
+    - update token usage
+    - exhaust all tpm with just 1 project
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+    for _ in range(5):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert availability == prev_availability - 10
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
@@ -1730,3 +1730,96 @@ async def test_router_text_completion_client():
         print(responses)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
+@pytest.mark.asyncio
+async def test_router_model_usage(mock_response):
+    model = "my-fake-model"
+    model_tpm = 100
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    print(f"mock_response: {mock_response}")
+    model_tpm = 100
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+
+    allowed_fails = 1  # allow for changing b/w minutes
+
+    for _ in range(2):
+        try:
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+            await asyncio.sleep(3)
+
+            initial_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            # completion call - 10 tokens
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+
+            await asyncio.sleep(3)
+            updated_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            assert updated_usage == initial_usage + 10  # type: ignore
+            break
+        except Exception as e:
+            if allowed_fails > 0:
+                print(
+                    f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
+                )
+                allowed_fails -= 1
+            else:
+                print(f"allowed_fails: {allowed_fails}")
+                raise e