Merge branch 'main' into litellm_fix_in_mem_usage

This commit is contained in:
Ishaan Jaff 2024-06-22 19:23:37 -07:00 committed by GitHub
commit 8e3a073323
17 changed files with 1332 additions and 126 deletions

View file

@ -152,3 +152,104 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
```
### Dynamic TPM Allocation
Prevent any one project from gobbling too much quota.
Dynamically allocate a model's TPM quota across API keys, based on how many keys are active in that minute.
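At a high level, each key's share is the model's remaining TPM for the current minute divided by the number of keys that called it in that minute (this mirrors the `check_available_tpm` logic added further down in this diff). A minimal sketch of the arithmetic, with illustrative numbers taken from the example below:
```python
# Illustrative numbers only - the proxy computes these from its usage cache.
model_tpm = 60            # total TPM configured for the model group
used_this_minute = 0      # tokens already consumed this minute
active_keys = 2           # distinct keys that called the model this minute

remaining_model_tpm = model_tpm - used_this_minute               # 60
available_tpm_per_key = int(remaining_model_tpm / active_keys)   # 30

# A request costing 30 tokens exhausts a key's share for the minute,
# so a second request from the same key is rejected with a 429.
```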
1. Setup config.yaml
```yaml
model_list:
- model_name: my-fake-model
litellm_params:
model: gpt-3.5-turbo
api_key: my-fake-key
mock_response: hello-world
tpm: 60
litellm_settings:
callbacks: ["dynamic_rate_limiter"]
general_settings:
master_key: sk-1234 # OR set `LITELLM_MASTER_KEY=".."` in your .env
database_url: postgres://.. # OR set `DATABASE_URL=".."` in your .env
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```python
"""
- Run 2 concurrent teams calling same model
- model has 60 TPM
- Mock response returns 30 total tokens / request
- Each team will only be able to make 1 request per minute
"""
"""
- Run 2 concurrent teams calling same model
- model has 60 TPM
- Mock response returns 30 total tokens / request
- Each team will only be able to make 1 request per minute
"""
import requests
from openai import OpenAI, RateLimitError
def create_key(api_key: str, base_url: str):
response = requests.post(
url="{}/key/generate".format(base_url),
json={},
headers={
"Authorization": "Bearer {}".format(api_key)
}
)
_response = response.json()
return _response["key"]
key_1 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
key_2 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
# call proxy with key 1 - works
openai_client_1 = OpenAI(api_key=key_1, base_url="http://0.0.0.0:4000")
response = openai_client_1.chat.completions.with_raw_response.create(
model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)
print("Headers for call 1 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))
# call proxy with key 2 - works
openai_client_2 = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")
response = openai_client_2.chat.completions.with_raw_response.create(
model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)
print("Headers for call 2 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))
# call proxy with key 2 - fails
try:
openai_client_2.chat.completions.with_raw_response.create(model="my-fake-model", messages=[{"role": "user", "content": "Hey, how's it going?"}])
raise Exception("This should have failed!")
except RateLimitError as e:
print("This was rate limited b/c - {}".format(str(e)))
```
**Expected Response**
```
This was rate limited b/c - Error code: 429 - {'error': {'message': {'error': 'Key=<hashed_token> over available TPM=0. Model TPM=0, Active keys=2'}, 'type': 'None', 'param': 'None', 'code': 429}}
```
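The proxy also reports the allocation on each successful response: `async_post_call_success_hook` attaches the values as `additional_headers`, which the proxy forwards via `get_custom_headers` (both changes are shown later in this diff). A short example of inspecting them with the raw-response client from step 3 above (header names are taken from this PR; `key_1` comes from the earlier script):
```python
from openai import OpenAI

client = OpenAI(api_key=key_1, base_url="http://0.0.0.0:4000")  # key_1 from step 3
response = client.chat.completions.with_raw_response.create(
    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)
# Headers attached by the dynamic rate limiter hook in this PR
print(response.headers.get("x-litellm-model_group"))
print(response.headers.get("x-ratelimit-remaining-litellm-project-tokens"))
print(response.headers.get("x-ratelimit-remaining-model-tokens"))
print(response.headers.get("x-ratelimit-current-active-projects"))
```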

View file

@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = []
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"] _custom_logger_compatible_callbacks_literal = Literal[
"lago", "openmeter", "logfire", "dynamic_rate_limiter"
]
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[ _langfuse_default_tags: Optional[
List[ List[
@ -735,6 +737,7 @@ from .utils import (
client, client,
exception_type, exception_type,
get_optional_params, get_optional_params,
get_response_string,
modify_integration, modify_integration,
token_counter, token_counter,
create_pretrained_tokenizer, create_pretrained_tokenizer,

View file

@ -1,11 +1,12 @@
from datetime import datetime from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm import litellm
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from .types.services import ServiceTypes, ServiceLoggerPayload
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.custom_logger import CustomLogger from .integrations.custom_logger import CustomLogger
from datetime import timedelta from .integrations.prometheus_services import PrometheusServicesLogger
from typing import Union, Optional, TYPE_CHECKING, Any from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger):
call_type: str, call_type: str,
duration: float, duration: float,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None, start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[datetime] = None, end_time: Optional[Union[datetime, float]] = None,
): ):
""" """
- For counting if the redis, postgres call is successful - For counting if the redis, postgres call is successful
@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger):
error: Union[str, Exception], error: Union[str, Exception],
call_type: str, call_type: str,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None, start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[datetime] = None, end_time: Optional[Union[float, datetime]] = None,
): ):
""" """
- For counting if the redis, postgres call is unsuccessful - For counting if the redis, postgres call is unsuccessful

View file

@ -14,6 +14,7 @@ import json
import logging import logging
import time import time
import traceback import traceback
from datetime import timedelta
from typing import Any, BinaryIO, List, Literal, Optional, Union from typing import Any, BinaryIO, List, Literal, Optional, Union
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
@ -92,9 +93,22 @@ class InMemoryCache(BaseCache):
else: else:
self.set_cache(key=cache_key, value=cache_value) self.set_cache(key=cache_key, value=cache_value)
if time.time() - self.last_cleaned > self.default_ttl: if time.time() - self.last_cleaned > self.default_ttl:
asyncio.create_task(self.clean_up_in_memory_cache()) asyncio.create_task(self.clean_up_in_memory_cache())
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
"""
Add value to set
"""
# get the value
init_value = self.get_cache(key=key) or set()
for val in value:
init_value.add(val)
self.set_cache(key, init_value, ttl=ttl)
return value
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
if key in self.cache_dict: if key in self.cache_dict:
if key in self.ttl_dict: if key in self.ttl_dict:
@ -363,6 +377,7 @@ class RedisCache(BaseCache):
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
call_type="async_set_cache",
) )
) )
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
@ -482,6 +497,80 @@ class RedisCache(BaseCache):
cache_value, cache_value,
) )
async def async_set_cache_sadd(
self, key, value: List, ttl: Optional[float], **kwargs
):
start_time = time.time()
try:
_redis_client = self.init_async_client()
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
call_type="async_set_cache_sadd",
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
raise e
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
print_verbose(
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.sadd(key, *value)
if ttl is not None:
_td = timedelta(seconds=ttl)
await redis_client.expire(key, _td)
print_verbose(
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="async_set_cache_sadd",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="async_set_cache_sadd",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
# NON blocking - notify users Redis is throwing an exception
verbose_logger.error(
"LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
async def batch_cache_write(self, key, value, **kwargs): async def batch_cache_write(self, key, value, **kwargs):
print_verbose( print_verbose(
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
@ -1506,7 +1595,7 @@ class DualCache(BaseCache):
key, value, **kwargs key, value, **kwargs
) )
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment(key, value, **kwargs) result = await self.redis_cache.async_increment(key, value, **kwargs)
return result return result
@ -1515,6 +1604,38 @@ class DualCache(BaseCache):
verbose_logger.debug(traceback.format_exc()) verbose_logger.debug(traceback.format_exc())
raise e raise e
async def async_set_cache_sadd(
self, key, value: List, local_only: bool = False, **kwargs
) -> None:
"""
Add value to a set
Key - the key in cache
Value - str - the value you want to add to the set
Returns - None
"""
try:
if self.in_memory_cache is not None:
_ = await self.in_memory_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
if self.redis_cache is not None and local_only is False:
_ = await self.redis_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None), **kwargs
)
return None
except Exception as e:
verbose_logger.error(
"LiteLLM Cache: Excepton async set_cache_sadd: {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise e
def flush_cache(self): def flush_cache(self):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache() self.in_memory_cache.flush_cache()
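The new `async_set_cache_sadd` helpers above are what the dynamic rate limiter uses to track which keys touched a model in the current minute: members are added to a per-minute set (Redis `SADD` plus a TTL, or a plain Python set in memory), and the set's size is the number of active projects. A minimal sketch of that pattern against the `DualCache` API added here (in-memory only, no Redis; key format borrowed from the hook further down):
```python
import asyncio

from litellm.caching import DualCache
from litellm.utils import get_utc_datetime


async def main():
    cache = DualCache()  # in-memory only for this sketch
    minute = get_utc_datetime().strftime("%H-%M")
    set_key = f"{minute}:my-fake-model"
    # two different keys call the model in the same minute
    await cache.async_set_cache_sadd(key=set_key, value=["key-1"], ttl=60)
    await cache.async_set_cache_sadd(key=set_key, value=["key-2"], ttl=60)
    members = await cache.async_get_cache(key=set_key) or set()
    print("active keys this minute:", len(members))  # -> 2


asyncio.run(main())
```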

View file

@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger):
self, self,
payload: ServiceLoggerPayload, payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None, start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[datetime] = None, end_time: Optional[Union[datetime, float]] = None,
): ):
from datetime import datetime from datetime import datetime
@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger):
self, self,
payload: ServiceLoggerPayload, payload: ServiceLoggerPayload,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
start_time: Optional[datetime] = None, start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[datetime] = None, end_time: Optional[Union[float, datetime]] = None,
): ):
from datetime import datetime from datetime import datetime

View file

@ -19,7 +19,8 @@ from litellm import (
turn_off_message_logging, turn_off_message_logging,
verbose_logger, verbose_logger,
) )
from litellm.caching import InMemoryCache, S3Cache
from litellm.caching import InMemoryCache, S3Cache, DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import ( from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging, redact_message_input_output_from_logging,
@ -1899,7 +1900,11 @@ def set_callbacks(callback_list, function_id=None):
def _init_custom_logger_compatible_class( def _init_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal, logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Callable: internal_usage_cache: Optional[DualCache],
llm_router: Optional[
Any
], # expect litellm.Router, but typing errors due to circular import
) -> CustomLogger:
if logging_integration == "lago": if logging_integration == "lago":
for callback in _in_memory_loggers: for callback in _in_memory_loggers:
if isinstance(callback, LagoLogger): if isinstance(callback, LagoLogger):
@ -1935,3 +1940,58 @@ def _init_custom_logger_compatible_class(
_otel_logger = OpenTelemetry(config=otel_config) _otel_logger = OpenTelemetry(config=otel_config)
_in_memory_loggers.append(_otel_logger) _in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore return _otel_logger # type: ignore
elif logging_integration == "dynamic_rate_limiter":
from litellm.proxy.hooks.dynamic_rate_limiter import (
_PROXY_DynamicRateLimitHandler,
)
for callback in _in_memory_loggers:
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
return callback # type: ignore
if internal_usage_cache is None:
raise Exception(
"Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
internal_usage_cache
)
)
dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
internal_usage_cache=internal_usage_cache
)
if llm_router is not None and isinstance(llm_router, litellm.Router):
dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
_in_memory_loggers.append(dynamic_rate_limiter_obj)
return dynamic_rate_limiter_obj # type: ignore
def get_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Optional[CustomLogger]:
if logging_integration == "lago":
for callback in _in_memory_loggers:
if isinstance(callback, LagoLogger):
return callback
elif logging_integration == "openmeter":
for callback in _in_memory_loggers:
if isinstance(callback, OpenMeterLogger):
return callback
elif logging_integration == "logfire":
if "LOGFIRE_TOKEN" not in os.environ:
raise ValueError("LOGFIRE_TOKEN not found in environment variables")
from litellm.integrations.opentelemetry import OpenTelemetry
for callback in _in_memory_loggers:
if isinstance(callback, OpenTelemetry):
return callback # type: ignore
elif logging_integration == "dynamic_rate_limiter":
from litellm.proxy.hooks.dynamic_rate_limiter import (
_PROXY_DynamicRateLimitHandler,
)
for callback in _in_memory_loggers:
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
return callback # type: ignore
return None
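Both the proxy hooks and `function_setup` now resolve string callbacks through these helpers: a name like `"dynamic_rate_limiter"` is looked up in `_in_memory_loggers` so the same handler instance is reused everywhere. A minimal hedged sketch of the lookup pattern (the function is imported directly from the module this diff modifies; it returns `None` if the handler was never initialized on startup):
```python
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import (
    get_custom_logger_compatible_class,
)

litellm.callbacks = ["dynamic_rate_limiter"]  # string form, as set by the config

for callback in litellm.callbacks:
    _callback = None
    if isinstance(callback, str):
        # Returns the already-initialized handler instance for this name, or None.
        _callback = get_custom_logger_compatible_class(callback)
    elif isinstance(callback, CustomLogger):
        _callback = callback
    print(callback, "->", type(_callback).__name__ if _callback else None)
```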

View file

@ -428,7 +428,7 @@ def mock_completion(
model: str, model: str,
messages: List, messages: List,
stream: Optional[bool] = False, stream: Optional[bool] = False,
mock_response: Union[str, Exception] = "This is a mock request", mock_response: Union[str, Exception, dict] = "This is a mock request",
mock_tool_calls: Optional[List] = None, mock_tool_calls: Optional[List] = None,
logging=None, logging=None,
custom_llm_provider=None, custom_llm_provider=None,
@ -477,6 +477,9 @@ def mock_completion(
if time_delay is not None: if time_delay is not None:
time.sleep(time_delay) time.sleep(time_delay)
if isinstance(mock_response, dict):
return ModelResponse(**mock_response)
model_response = ModelResponse(stream=stream) model_response = ModelResponse(stream=stream)
if stream is True: if stream is True:
# don't try to access stream object, # don't try to access stream object,
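`mock_completion` can now be seeded with a full response dict, which it rehydrates via `ModelResponse(**mock_response)`; the router tests later in this diff rely on this to control `usage.total_tokens`. A hedged sketch, assuming `litellm.completion` forwards a dict `mock_response` to `mock_completion` the same way it forwards strings:
```python
import litellm

mock = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": 1699896916,
    "model": "gpt-3.5-turbo-0125",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "hello-world"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
}

# No provider call is made; the dict is returned as a ModelResponse.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response=mock,
)
print(response.usage.total_tokens)  # -> 10
```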

View file

@ -1,61 +1,10 @@
environment_variables:
LANGFUSE_PUBLIC_KEY: Q6K8MQN6L7sPYSJiFKM9eNrETOx6V/FxVPup4FqdKsZK1hyR4gyanlQ2KHLg5D5afng99uIt0JCEQ2jiKF9UxFvtnb4BbJ4qpeceH+iK8v/bdg==
LANGFUSE_SECRET_KEY: 5xQ7KMa6YMLsm+H/Pf1VmlqWq1NON5IoCxABhkUBeSck7ftsj2CmpkL2ZwrxwrktgiTUBH+3gJYBX+XBk7lqOOUpvmiLjol/E5lCqq0M1CqLWA==
SLACK_WEBHOOK_URL: RJjhS0Hhz0/s07sCIf1OTXmTGodpK9L2K9p953Z+fOX0l2SkPFT6mB9+yIrLufmlwEaku5NNEBKy//+AG01yOd+7wV1GhK65vfj3B/gTN8t5cuVnR4vFxKY5Rx4eSGLtzyAs+aIBTp4GoNXDIjroCqfCjPkItEZWCg==
general_settings:
alerting:
- slack
alerting_threshold: 300
database_connection_pool_limit: 100
database_connection_timeout: 60
disable_master_key_return: true
health_check_interval: 300
proxy_batch_write_at: 60
ui_access_mode: all
# master_key: sk-1234
litellm_settings:
allowed_fails: 3
failure_callback:
- prometheus
num_retries: 3
service_callback:
- prometheus_system
success_callback:
- langfuse
- prometheus
- langsmith
model_list: model_list:
- litellm_params: - model_name: my-fake-model
model: gpt-3.5-turbo litellm_params:
model_name: gpt-3.5-turbo model: gpt-3.5-turbo
- litellm_params: api_key: my-fake-key
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ mock_response: hello-world
api_key: my-fake-key tpm: 60
model: openai/my-fake-model
stream_timeout: 0.001 litellm_settings:
model_name: fake-openai-endpoint callbacks: ["dynamic_rate_limiter"]
- litellm_params:
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_key: my-fake-key
model: openai/my-fake-model-2
stream_timeout: 0.001
model_name: fake-openai-endpoint
- litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
model: azure/chatgpt-v-2
stream_timeout: 0.001
model_name: azure-gpt-3.5
- litellm_params:
api_key: os.environ/OPENAI_API_KEY
model: text-embedding-ada-002
model_name: text-embedding-ada-002
- litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct
model_name: gpt-instruct
router_settings:
enable_pre_call_checks: true
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT

View file

@ -30,6 +30,7 @@ model_list:
api_key: os.environ/AZURE_API_KEY api_key: os.environ/AZURE_API_KEY
api_version: 2024-02-15-preview api_version: 2024-02-15-preview
model: azure/chatgpt-v-2 model: azure/chatgpt-v-2
tpm: 100
model_name: gpt-3.5-turbo model_name: gpt-3.5-turbo
- litellm_params: - litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0 model: anthropic.claude-3-sonnet-20240229-v1:0
@ -40,6 +41,7 @@ model_list:
api_version: 2024-02-15-preview api_version: 2024-02-15-preview
model: azure/chatgpt-v-2 model: azure/chatgpt-v-2
drop_params: True drop_params: True
tpm: 100
model_name: gpt-3.5-turbo model_name: gpt-3.5-turbo
- model_name: tts - model_name: tts
litellm_params: litellm_params:
@ -67,8 +69,7 @@ model_list:
max_input_tokens: 80920 max_input_tokens: 80920
litellm_settings: litellm_settings:
success_callback: ["langfuse"] callbacks: ["dynamic_rate_limiter"]
failure_callback: ["langfuse"]
# default_team_settings: # default_team_settings:
# - team_id: proj1 # - team_id: proj1
# success_callback: ["langfuse"] # success_callback: ["langfuse"]

View file

@ -0,0 +1,205 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute
import asyncio
import sys
import traceback
from datetime import datetime
from typing import List, Literal, Optional, Tuple, Union
from fastapi import HTTPException
import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.utils import get_utc_datetime
class DynamicRateLimiterCache:
"""
Thin wrapper on DualCache for this file.
Track number of active projects calling a model.
"""
def __init__(self, cache: DualCache) -> None:
self.cache = cache
self.ttl = 60 # 1 min ttl
async def async_get_cache(self, model: str) -> Optional[int]:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
_response = await self.cache.async_get_cache(key=key_name)
response: Optional[int] = None
if _response is not None:
response = len(_response)
return response
async def async_set_cache_sadd(self, model: str, value: List):
"""
Add value to set.
Parameters:
- model: str, the name of the model group
- value: str, the team id
Returns:
- None
Raises:
- Exception, if unable to connect to cache client (if redis caching enabled)
"""
try:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
await self.cache.async_set_cache_sadd(
key=key_name, value=value, ttl=self.ttl
)
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise e
class _PROXY_DynamicRateLimitHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: DualCache):
self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
def update_variables(self, llm_router: Router):
self.llm_router = llm_router
async def check_available_tpm(
self, model: str
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""
For a given model, get its available tpm
Returns
- Tuple[available_tpm, model_tpm, active_projects]
- available_tpm: int or null - always 0 or positive.
- remaining_model_tpm: int or null. If available tpm is int, then this will be too.
- active_projects: int or null
"""
active_projects = await self.internal_usage_cache.async_get_cache(model=model)
current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
model_group=model
)
model_group_info: Optional[ModelGroupInfo] = (
self.llm_router.get_model_group_info(model_group=model)
)
total_model_tpm: Optional[int] = None
if model_group_info is not None and model_group_info.tpm is not None:
total_model_tpm = model_group_info.tpm
remaining_model_tpm: Optional[int] = None
if total_model_tpm is not None and current_model_tpm is not None:
remaining_model_tpm = total_model_tpm - current_model_tpm
elif total_model_tpm is not None:
remaining_model_tpm = total_model_tpm
available_tpm: Optional[int] = None
if remaining_model_tpm is not None:
if active_projects is not None:
available_tpm = int(remaining_model_tpm / active_projects)
else:
available_tpm = remaining_model_tpm
if available_tpm is not None and available_tpm < 0:
available_tpm = 0
return available_tpm, remaining_model_tpm, active_projects
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[
Union[Exception, str, dict]
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
"""
- For a model group
- Check if tpm available
- Raise RateLimitError if no tpm available
"""
if "model" in data:
available_tpm, model_tpm, active_projects = await self.check_available_tpm(
model=data["model"]
)
if available_tpm is not None and available_tpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_tpm,
model_tpm,
active_projects,
)
},
)
elif available_tpm is not None:
## UPDATE CACHE WITH ACTIVE PROJECT
asyncio.create_task(
self.internal_usage_cache.async_set_cache_sadd( # this is a set
model=data["model"], # type: ignore
value=[user_api_key_dict.token or "default_key"],
)
)
return None
async def async_post_call_success_hook(
self, user_api_key_dict: UserAPIKeyAuth, response
):
try:
if isinstance(response, ModelResponse):
model_info = self.llm_router.get_model_info(
id=response._hidden_params["model_id"]
)
assert (
model_info is not None
), "Model info for model with id={} is None".format(
response._hidden_params["model_id"]
)
available_tpm, remaining_model_tpm, active_projects = (
await self.check_available_tpm(model=model_info["model_name"])
)
response._hidden_params["additional_headers"] = {
"x-litellm-model_group": model_info["model_name"],
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
"x-ratelimit-remaining-model-tokens": remaining_model_tpm,
"x-ratelimit-current-active-projects": active_projects,
}
return response
return await super().async_post_call_success_hook(
user_api_key_dict, response
)
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
return response
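The heart of the hook is `check_available_tpm`: remaining model TPM (group TPM minus usage so far this minute) divided evenly across active projects, floored at zero. A worked example with illustrative numbers:
```python
# Illustrative values - in the hook these come from the router and the usage cache.
total_model_tpm = 100        # sum of `tpm` across deployments in the model group
current_model_tpm = 40       # tokens already used this minute (router usage cache)
active_projects = 3          # size of the per-minute set of active keys

remaining_model_tpm = total_model_tpm - current_model_tpm   # 60
available_tpm = int(remaining_model_tpm / active_projects)  # 20 per project

# If usage had exceeded the group TPM, available_tpm would clamp to 0
# and async_pre_call_hook would raise an HTTP 429.
```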

View file

@ -433,6 +433,7 @@ def get_custom_headers(
version: Optional[str] = None, version: Optional[str] = None,
model_region: Optional[str] = None, model_region: Optional[str] = None,
fastest_response_batch_completion: Optional[bool] = None, fastest_response_batch_completion: Optional[bool] = None,
**kwargs,
) -> dict: ) -> dict:
exclude_values = {"", None} exclude_values = {"", None}
headers = { headers = {
@ -448,6 +449,7 @@ def get_custom_headers(
if fastest_response_batch_completion is not None if fastest_response_batch_completion is not None
else None else None
), ),
**{k: str(v) for k, v in kwargs.items()},
} }
try: try:
return { return {
@ -2644,7 +2646,9 @@ async def startup_event():
redis_cache=redis_usage_cache redis_cache=redis_usage_cache
) # used by parallel request limiter for rate limiting keys across instances ) # used by parallel request limiter for rate limiting keys across instances
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made proxy_logging_obj._init_litellm_callbacks(
llm_router=llm_router
) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types: if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
asyncio.create_task( asyncio.create_task(
@ -3061,6 +3065,14 @@ async def chat_completion(
headers=custom_headers, headers=custom_headers,
) )
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = getattr(response, "_hidden_params", {}) or {}
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update( fastapi_response.headers.update(
get_custom_headers( get_custom_headers(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
@ -3070,14 +3082,10 @@ async def chat_completion(
version=version, version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""), model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion, fastest_response_batch_completion=fastest_response_batch_completion,
**additional_headers,
) )
) )
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
)
return response return response
except RejectedRequestError as e: except RejectedRequestError as e:
_data = e.request_data _data = e.request_data
@ -3116,11 +3124,10 @@ async def chat_completion(
except Exception as e: except Exception as e:
data["litellm_status"] = "fail" # used for alerting data["litellm_status"] = "fail" # used for alerting
verbose_proxy_logger.error( verbose_proxy_logger.error(
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format( "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
get_error_message_str(e=e) get_error_message_str(e=e), traceback.format_exc()
) )
) )
verbose_proxy_logger.debug(traceback.format_exc())
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )

View file

@ -229,31 +229,32 @@ class ProxyLogging:
if redis_cache is not None: if redis_cache is not None:
self.internal_usage_cache.redis_cache = redis_cache self.internal_usage_cache.redis_cache = redis_cache
def _init_litellm_callbacks(self): def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
print_verbose("INITIALIZING LITELLM CALLBACKS!")
self.service_logging_obj = ServiceLogging() self.service_logging_obj = ServiceLogging()
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.max_budget_limiter) # type: ignore
litellm.callbacks.append(self.cache_control_check) litellm.callbacks.append(self.cache_control_check) # type: ignore
litellm.callbacks.append(self.service_logging_obj) litellm.callbacks.append(self.service_logging_obj) # type: ignore
litellm.success_callback.append( litellm.success_callback.append(
self.slack_alerting_instance.response_taking_too_long_callback self.slack_alerting_instance.response_taking_too_long_callback
) )
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, str): if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
callback callback,
internal_usage_cache=self.internal_usage_cache,
llm_router=llm_router,
) )
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback: if callback not in litellm.success_callback:
litellm.success_callback.append(callback) litellm.success_callback.append(callback) # type: ignore
if callback not in litellm.failure_callback: if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback) litellm.failure_callback.append(callback) # type: ignore
if callback not in litellm._async_success_callback: if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback) litellm._async_success_callback.append(callback) # type: ignore
if callback not in litellm._async_failure_callback: if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) litellm._async_failure_callback.append(callback) # type: ignore
if ( if (
len(litellm.input_callback) > 0 len(litellm.input_callback) > 0
@ -301,10 +302,19 @@ class ProxyLogging:
try: try:
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars( _callback: Optional[CustomLogger] = None
callback.__class__ if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if (
_callback is not None
and isinstance(_callback, CustomLogger)
and "async_pre_call_hook" in vars(_callback.__class__)
): ):
response = await callback.async_pre_call_hook( response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"], cache=self.call_details["user_api_key_cache"],
data=data, data=data,
@ -574,8 +584,15 @@ class ProxyLogging:
for callback in litellm.callbacks: for callback in litellm.callbacks:
try: try:
if isinstance(callback, CustomLogger): _callback: Optional[CustomLogger] = None
await callback.async_post_call_failure_hook( if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_failure_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
original_exception=original_exception, original_exception=original_exception,
) )
@ -596,8 +613,15 @@ class ProxyLogging:
""" """
for callback in litellm.callbacks: for callback in litellm.callbacks:
try: try:
if isinstance(callback, CustomLogger): _callback: Optional[CustomLogger] = None
await callback.async_post_call_success_hook( if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response user_api_key_dict=user_api_key_dict, response=response
) )
except Exception as e: except Exception as e:
@ -615,14 +639,25 @@ class ProxyLogging:
Covers: Covers:
1. /chat/completions 1. /chat/completions
""" """
for callback in litellm.callbacks: response_str: Optional[str] = None
try: if isinstance(response, ModelResponse):
if isinstance(callback, CustomLogger): response_str = litellm.get_response_string(response_obj=response)
await callback.async_post_call_streaming_hook( if response_str is not None:
user_api_key_dict=user_api_key_dict, response=response for callback in litellm.callbacks:
) try:
except Exception as e: _callback: Optional[CustomLogger] = None
raise e if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=response_str
)
except Exception as e:
raise e
return response return response
async def post_call_streaming_hook( async def post_call_streaming_hook(

View file

@ -11,6 +11,7 @@ import asyncio
import concurrent import concurrent
import copy import copy
import datetime as datetime_og import datetime as datetime_og
import enum
import hashlib import hashlib
import inspect import inspect
import json import json
@ -90,6 +91,10 @@ from litellm.utils import (
) )
class RoutingArgs(enum.Enum):
ttl = 60 # 1min (RPM/TPM expire key)
class Router: class Router:
model_names: List = [] model_names: List = []
cache_responses: Optional[bool] = False cache_responses: Optional[bool] = False
@ -387,6 +392,11 @@ class Router:
routing_strategy=routing_strategy, routing_strategy=routing_strategy,
routing_strategy_args=routing_strategy_args, routing_strategy_args=routing_strategy_args,
) )
## USAGE TRACKING ##
if isinstance(litellm._async_success_callback, list):
litellm._async_success_callback.append(self.deployment_callback_on_success)
else:
litellm._async_success_callback.append(self.deployment_callback_on_success)
## COOLDOWNS ## ## COOLDOWNS ##
if isinstance(litellm.failure_callback, list): if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure) litellm.failure_callback.append(self.deployment_callback_on_failure)
@ -2640,13 +2650,69 @@ class Router:
time.sleep(_timeout) time.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries setattr(original_exception, "max_retries", num_retries)
original_exception.num_retries = current_attempt setattr(original_exception, "num_retries", current_attempt)
raise original_exception raise original_exception
### HELPER FUNCTIONS ### HELPER FUNCTIONS
async def deployment_callback_on_success(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
"""
Track remaining tpm/rpm quota for model in model_list
"""
try:
"""
Update TPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = completion_response["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"global_router:{id}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
await self.cache.async_increment_cache(
key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
)
except Exception as e:
verbose_router_logger.error(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
pass
def deployment_callback_on_failure( def deployment_callback_on_failure(
self, self,
kwargs, # kwargs to completion kwargs, # kwargs to completion
@ -3812,10 +3878,39 @@ class Router:
model_group_info: Optional[ModelGroupInfo] = None model_group_info: Optional[ModelGroupInfo] = None
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
for model in self.model_list: for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group: if "model_name" in model and model["model_name"] == model_group:
# model in model group found # # model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"]) litellm_params = LiteLLM_Params(**model["litellm_params"])
# get model tpm
_deployment_tpm: Optional[int] = None
if _deployment_tpm is None:
_deployment_tpm = model.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = model.get("model_info", {}).get("tpm", None)
if _deployment_tpm is not None:
if total_tpm is None:
total_tpm = 0
total_tpm += _deployment_tpm # type: ignore
# get model rpm
_deployment_rpm: Optional[int] = None
if _deployment_rpm is None:
_deployment_rpm = model.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = model.get("model_info", {}).get("rpm", None)
if _deployment_rpm is not None:
if total_rpm is None:
total_rpm = 0
total_rpm += _deployment_rpm # type: ignore
# get model info # get model info
try: try:
model_info = litellm.get_model_info(model=litellm_params.model) model_info = litellm.get_model_info(model=litellm_params.model)
@ -3929,8 +4024,44 @@ class Router:
"supported_openai_params" "supported_openai_params"
] ]
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
if total_tpm is not None and model_group_info is not None:
model_group_info.tpm = total_tpm
if total_rpm is not None and model_group_info is not None:
model_group_info.rpm = total_rpm
return model_group_info return model_group_info
async def get_model_group_usage(self, model_group: str) -> Optional[int]:
"""
Returns remaining tpm quota for model group
"""
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_keys: List[str] = []
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
tpm_keys.append(
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
)
## TPM
tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
keys=tpm_keys
)
tpm_usage: Optional[int] = None
if tpm_usage_list is not None:
for t in tpm_usage_list:
if isinstance(t, int):
if tpm_usage is None:
tpm_usage = 0
tpm_usage += t
return tpm_usage
def get_model_ids(self) -> List[str]: def get_model_ids(self) -> List[str]:
""" """
Returns list of model id's. Returns list of model id's.
@ -4858,7 +4989,7 @@ class Router:
def reset(self): def reset(self):
## clean up on close ## clean up on close
litellm.success_callback = [] litellm.success_callback = []
litellm.__async_success_callback = [] litellm._async_success_callback = []
litellm.failure_callback = [] litellm.failure_callback = []
litellm._async_failure_callback = [] litellm._async_failure_callback = []
self.retry_policy = None self.retry_policy = None
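`deployment_callback_on_success` writes each call's `usage.total_tokens` to a per-deployment, per-minute cache key, and `get_model_group_usage` sums those keys across the group. A small standalone sketch of the key scheme and the summation (key format taken from this diff; values are illustrative):
```python
from datetime import datetime, timezone

# Key format used by the router in this PR: global_router:<deployment_id>:tpm:<HH-MM>
current_minute = datetime.now(timezone.utc).strftime("%H-%M")
deployment_usage = {
    f"global_router:model-id-1:tpm:{current_minute}": 30,  # tokens used this minute
    f"global_router:model-id-2:tpm:{current_minute}": 30,
}

# get_model_group_usage effectively batch-gets these keys and sums the integers
model_group_usage = sum(v for v in deployment_usage.values() if isinstance(v, int))
print(model_group_usage)  # -> 60 tokens used by the model group this minute
```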

View file

@ -0,0 +1,486 @@
# What is this?
## Unit tests for `dynamic_rate_limiter.py`
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from datetime import datetime
from typing import Optional, Tuple
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import DualCache, Router
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.dynamic_rate_limiter import (
_PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
)
"""
Basic test cases:
- If 1 'active' project => give all tpm
- If 2 'active' projects => divide tpm in 2
"""
@pytest.fixture
def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
internal_cache = DualCache()
return DynamicRateLimitHandler(internal_usage_cache=internal_cache)
@pytest.fixture
def mock_response() -> litellm.ModelResponse:
return litellm.ModelResponse(
**{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1699896916,
"model": "gpt-3.5-turbo-0125",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": '{\n"location": "Boston, MA"\n}',
},
}
],
},
"logprobs": None,
"finish_reason": "tool_calls",
}
],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
}
)
@pytest.fixture
def user_api_key_auth() -> UserAPIKeyAuth:
return UserAPIKeyAuth()
@pytest.mark.parametrize("num_projects", [1, 2, 100])
@pytest.mark.asyncio
async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
model = "my-fake-model"
## SET CACHE W/ ACTIVE PROJECTS
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
model=model, value=projects
)
model_tpm = 100
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
## CHECK AVAILABLE TPM PER PROJECT
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
expected_availability = int(model_tpm / num_projects)
assert availability == expected_availability
@pytest.mark.asyncio
async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
"""
Unit test. Tests if rate limit error raised when quota exhausted.
"""
from fastapi import HTTPException
model = "my-fake-model"
## SET CACHE W/ ACTIVE PROJECTS
projects = [str(uuid.uuid4())]
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
model=model, value=projects
)
model_tpm = 0
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
## CHECK AVAILABLE TPM PER PROJECT
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
expected_availability = int(model_tpm / 1)
assert availability == expected_availability
## CHECK if exception raised
try:
await dynamic_rate_limit_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_auth,
cache=DualCache(),
data={"model": model},
call_type="completion",
)
pytest.fail("Expected this to raise HTTPException")
except HTTPException as e:
assert e.status_code == 429 # check if rate limit error raised
pass
@pytest.mark.asyncio
async def test_base_case(dynamic_rate_limit_handler, mock_response):
"""
If just 1 active project
it should get all the quota
= allow request to go through
- update token usage
- exhaust all tpm with just 1 project
- assert ratelimiterror raised at 100%+1 tpm
"""
model = "my-fake-model"
## model tpm - 50
model_tpm = 50
## tpm per request - 10
setattr(
mock_response,
"usage",
litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
)
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
"mock_response": mock_response,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
prev_availability: Optional[int] = None
allowed_fails = 1
for _ in range(5):
try:
# check availability
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
## assert availability updated
if prev_availability is not None and availability is not None:
assert availability == prev_availability - 10
print(
"prev_availability={}, availability={}".format(
prev_availability, availability
)
)
prev_availability = availability
# make call
await llm_router.acompletion(
model=model, messages=[{"role": "user", "content": "hey!"}]
)
await asyncio.sleep(3)
except Exception:
if allowed_fails > 0:
allowed_fails -= 1
else:
raise
@pytest.mark.asyncio
async def test_update_cache(
dynamic_rate_limit_handler, mock_response, user_api_key_auth
):
"""
Check if active project correctly updated
"""
model = "my-fake-model"
model_tpm = 50
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
"mock_response": mock_response,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
## INITIAL ACTIVE PROJECTS - ASSERT NONE
_, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
assert active_projects is None
## MAKE CALL
await dynamic_rate_limit_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_auth,
cache=DualCache(),
data={"model": model},
call_type="completion",
)
await asyncio.sleep(2)
## INITIAL ACTIVE PROJECTS - ASSERT 1
_, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
assert active_projects == 1
@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects(
dynamic_rate_limit_handler, mock_response, num_projects
):
"""
If 2 active project
it should split 50% each
- assert available tpm is 0 after 50%+1 tpm calls
"""
model = "my-fake-model"
model_tpm = 50
total_tokens_per_call = 10
step_tokens_per_call_per_project = total_tokens_per_call / num_projects
available_tpm_per_project = int(model_tpm / num_projects)
## SET CACHE W/ ACTIVE PROJECTS
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
model=model, value=projects
)
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
setattr(
mock_response,
"usage",
litellm.Usage(
prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
),
)
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
"mock_response": mock_response,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
prev_availability: Optional[int] = None
print("expected_runs: {}".format(expected_runs))
for i in range(expected_runs + 1):
# check availability
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
## assert availability updated
if prev_availability is not None and availability is not None:
assert (
availability == prev_availability - step_tokens_per_call_per_project
), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
availability,
prev_availability - 10,
i,
step_tokens_per_call_per_project,
model_tpm,
)
print(
"prev_availability={}, availability={}".format(
prev_availability, availability
)
)
prev_availability = availability
# make call
await llm_router.acompletion(
model=model, messages=[{"role": "user", "content": "hey!"}]
)
await asyncio.sleep(3)
# check availability
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
assert availability == 0
@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects_e2e(
dynamic_rate_limit_handler, mock_response, num_projects
):
"""
2 parallel calls with different keys, same model
If 2 active project
it should split 50% each
- assert available tpm is 0 after 50%+1 tpm calls
"""
model = "my-fake-model"
model_tpm = 50
total_tokens_per_call = 10
step_tokens_per_call_per_project = total_tokens_per_call / num_projects
available_tpm_per_project = int(model_tpm / num_projects)
## SET CACHE W/ ACTIVE PROJECTS
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
model=model, value=projects
)
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
setattr(
mock_response,
"usage",
litellm.Usage(
prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
),
)
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
"mock_response": mock_response,
},
}
]
)
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
prev_availability: Optional[int] = None
print("expected_runs: {}".format(expected_runs))
for i in range(expected_runs + 1):
# check availability
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
## assert availability updated
if prev_availability is not None and availability is not None:
assert (
availability == prev_availability - step_tokens_per_call_per_project
), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
availability,
prev_availability - 10,
i,
step_tokens_per_call_per_project,
model_tpm,
)
print(
"prev_availability={}, availability={}".format(
prev_availability, availability
)
)
prev_availability = availability
# make call
await llm_router.acompletion(
model=model, messages=[{"role": "user", "content": "hey!"}]
)
await asyncio.sleep(3)
# check availability
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
model=model
)
assert availability == 0

View file

@ -1730,3 +1730,99 @@ async def test_router_text_completion_client():
print(responses) print(responses)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.fixture
def mock_response() -> litellm.ModelResponse:
return litellm.ModelResponse(
**{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1699896916,
"model": "gpt-3.5-turbo-0125",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": '{\n"location": "Boston, MA"\n}',
},
}
],
},
"logprobs": None,
"finish_reason": "tool_calls",
}
],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
}
)
@pytest.mark.asyncio
async def test_router_model_usage(mock_response):
"""
Test if tracking used model tpm works as expected
"""
model = "my-fake-model"
model_tpm = 100
setattr(
mock_response,
"usage",
litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
)
print(f"mock_response: {mock_response}")
model_tpm = 100
llm_router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "my-key",
"api_base": "my-base",
"tpm": model_tpm,
"mock_response": mock_response,
},
}
]
)
allowed_fails = 1 # allow for changing b/w minutes
for _ in range(2):
try:
_ = await llm_router.acompletion(
model=model, messages=[{"role": "user", "content": "Hey!"}]
)
await asyncio.sleep(3)
initial_usage = await llm_router.get_model_group_usage(model_group=model)
# completion call - 10 tokens
_ = await llm_router.acompletion(
model=model, messages=[{"role": "user", "content": "Hey!"}]
)
await asyncio.sleep(3)
updated_usage = await llm_router.get_model_group_usage(model_group=model)
assert updated_usage == initial_usage + 10 # type: ignore
break
except Exception as e:
if allowed_fails > 0:
print(
f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
)
allowed_fails -= 1
else:
print(f"allowed_fails: {allowed_fails}")
raise e

View file

@ -442,6 +442,8 @@ class ModelGroupInfo(BaseModel):
"chat", "embedding", "completion", "image_generation", "audio_transcription" "chat", "embedding", "completion", "image_generation", "audio_transcription"
] ]
] = Field(default="chat") ] = Field(default="chat")
tpm: Optional[int] = None
rpm: Optional[int] = None
supports_parallel_function_calling: bool = Field(default=False) supports_parallel_function_calling: bool = Field(default=False)
supports_vision: bool = Field(default=False) supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False) supports_function_calling: bool = Field(default=False)
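The new `tpm`/`rpm` fields on `ModelGroupInfo` are populated by `Router.get_model_group_info` (summed across deployments, see the router changes above), and the dynamic rate limiter reads them to learn the group's total quota. A brief hedged usage sketch, assuming the bundled model cost map resolves `gpt-3.5-turbo` locally:
```python
from litellm import Router

llm_router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "my-fake-key",
                "mock_response": "hello-world",
                "tpm": 60,
            },
        }
    ]
)

info = llm_router.get_model_group_info(model_group="my-fake-model")
if info is not None:
    print(info.tpm, info.rpm)  # -> 60 None (only tpm is configured here)
```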

View file

@ -340,14 +340,15 @@ def function_setup(
) )
try: try:
global callback_list, add_breadcrumb, user_logger_fn, Logging global callback_list, add_breadcrumb, user_logger_fn, Logging
function_id = kwargs["id"] if "id" in kwargs else None function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0: if len(litellm.callbacks) > 0:
for callback in litellm.callbacks: for callback in litellm.callbacks:
# check if callback is a string - e.g. "lago", "openmeter" # check if callback is a string - e.g. "lago", "openmeter"
if isinstance(callback, str): if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
callback callback, internal_usage_cache=None, llm_router=None
) )
if any( if any(
isinstance(cb, type(callback)) isinstance(cb, type(callback))
@ -3895,12 +3896,16 @@ def get_formatted_prompt(
def get_response_string(response_obj: ModelResponse) -> str: def get_response_string(response_obj: ModelResponse) -> str:
_choices: List[Choices] = response_obj.choices # type: ignore _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
response_str = "" response_str = ""
for choice in _choices: for choice in _choices:
if choice.message.content is not None: if isinstance(choice, Choices):
response_str += choice.message.content if choice.message.content is not None:
response_str += choice.message.content
elif isinstance(choice, StreamingChoices):
if choice.delta.content is not None:
response_str += choice.delta.content
return response_str return response_str
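With this change, `get_response_string` concatenates `message.content` for regular choices and `delta.content` for streaming chunks; the proxy's post-call streaming hook uses it to hand callbacks a plain string. A tiny hedged example of the non-streaming path, using the export added to `litellm/__init__.py` in this diff:
```python
import litellm

resp = litellm.ModelResponse(
    choices=[
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello world"},
            "finish_reason": "stop",
        }
    ]
)
print(litellm.get_response_string(response_obj=resp))  # -> "Hello world"
```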