Merge branch 'main' into litellm_fix_in_mem_usage

commit 8e3a073323

17 changed files with 1332 additions and 126 deletions
@@ -152,3 +152,104 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
 ```

### Dynamic TPM Allocation

Prevent projects from gobbling too much quota.

Dynamically allocate TPM quota to api keys, based on active keys in that minute.

1. Setup config.yaml

```yaml
model_list:
- model_name: my-fake-model
  litellm_params:
    model: gpt-3.5-turbo
    api_key: my-fake-key
    mock_response: hello-world
  tpm: 60

litellm_settings:
  callbacks: ["dynamic_rate_limiter"]

general_settings:
  master_key: sk-1234 # OR set `LITELLM_MASTER_KEY=".."` in your .env
  database_url: postgres://.. # OR set `DATABASE_URL=".."` in your .env
```

2. Start proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```python
"""
- Run 2 concurrent teams calling same model
- model has 60 TPM
- Mock response returns 30 total tokens / request
- Each team will only be able to make 1 request per minute
"""
import requests
from openai import OpenAI, RateLimitError


def create_key(api_key: str, base_url: str):
    response = requests.post(
        url="{}/key/generate".format(base_url),
        json={},
        headers={
            "Authorization": "Bearer {}".format(api_key)
        }
    )

    _response = response.json()

    return _response["key"]


key_1 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
key_2 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# call proxy with key 1 - works
openai_client_1 = OpenAI(api_key=key_1, base_url="http://0.0.0.0:4000")

response = openai_client_1.chat.completions.with_raw_response.create(
    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)

print("Headers for call 1 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))

# call proxy with key 2 - works
openai_client_2 = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")

response = openai_client_2.chat.completions.with_raw_response.create(
    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)

print("Headers for call 2 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))

# call proxy with key 2 - fails
try:
    openai_client_2.chat.completions.with_raw_response.create(model="my-fake-model", messages=[{"role": "user", "content": "Hey, how's it going?"}])
    raise Exception("This should have failed!")
except RateLimitError as e:
    print("This was rate limited b/c - {}".format(str(e)))
```

**Expected Response**

```
This was rate limited b/c - Error code: 429 - {'error': {'message': {'error': 'Key=<hashed_token> over available TPM=0. Model TPM=0, Active keys=2'}, 'type': 'None', 'param': 'None', 'code': 429}}
```
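The rule behind that 429 is a plain division: each minute, the model's remaining TPM is split evenly across the keys that were active in that minute. A minimal standalone sketch of that arithmetic (not the proxy's code; the numbers mirror the example above):

```python
# Hypothetical standalone sketch of the dynamic TPM split documented above.
def available_tpm_per_key(model_tpm: int, tokens_used_this_minute: int, active_keys: int) -> int:
    """Divide the model's remaining TPM evenly across keys active in the current minute."""
    remaining = max(model_tpm - tokens_used_this_minute, 0)
    if active_keys <= 0:
        return remaining
    return remaining // active_keys


# 60 TPM model, 2 active keys, nothing used yet -> each key may spend 30 tokens this minute,
# so one 30-token mock response consumes a key's full share.
print(available_tpm_per_key(model_tpm=60, tokens_used_this_minute=0, active_keys=2))   # 30
print(available_tpm_per_key(model_tpm=60, tokens_used_this_minute=60, active_keys=2))  # 0 -> 429
```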
@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
|
||||||
success_callback: List[Union[str, Callable]] = []
|
success_callback: List[Union[str, Callable]] = []
|
||||||
failure_callback: List[Union[str, Callable]] = []
|
failure_callback: List[Union[str, Callable]] = []
|
||||||
service_callback: List[Union[str, Callable]] = []
|
service_callback: List[Union[str, Callable]] = []
|
||||||
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
|
_custom_logger_compatible_callbacks_literal = Literal[
|
||||||
|
"lago", "openmeter", "logfire", "dynamic_rate_limiter"
|
||||||
|
]
|
||||||
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
||||||
_langfuse_default_tags: Optional[
|
_langfuse_default_tags: Optional[
|
||||||
List[
|
List[
|
||||||
|
@ -735,6 +737,7 @@ from .utils import (
|
||||||
client,
|
client,
|
||||||
exception_type,
|
exception_type,
|
||||||
get_optional_params,
|
get_optional_params,
|
||||||
|
get_response_string,
|
||||||
modify_integration,
|
modify_integration,
|
||||||
token_counter,
|
token_counter,
|
||||||
create_pretrained_tokenizer,
|
create_pretrained_tokenizer,
|
||||||
|
|
|
@@ -1,11 +1,12 @@
-from datetime import datetime
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Optional, Union
+
 import litellm
 from litellm.proxy._types import UserAPIKeyAuth
-from .types.services import ServiceTypes, ServiceLoggerPayload
-from .integrations.prometheus_services import PrometheusServicesLogger
+
 from .integrations.custom_logger import CustomLogger
-from datetime import timedelta
-from typing import Union, Optional, TYPE_CHECKING, Any
+from .integrations.prometheus_services import PrometheusServicesLogger
+from .types.services import ServiceLoggerPayload, ServiceTypes

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger):
         call_type: str,
         duration: float,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[datetime, float]] = None,
     ):
         """
         - For counting if the redis, postgres call is successful

@@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger):
         error: Union[str, Exception],
         call_type: str,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[float, datetime]] = None,
     ):
         """
         - For counting if the redis, postgres call is unsuccessful
@@ -14,6 +14,7 @@ import json
 import logging
 import time
 import traceback
+from datetime import timedelta
 from typing import Any, BinaryIO, List, Literal, Optional, Union

 from openai._models import BaseModel as OpenAIObject

@@ -92,9 +93,22 @@ class InMemoryCache(BaseCache):
         else:
             self.set_cache(key=cache_key, value=cache_value)

         if time.time() - self.last_cleaned > self.default_ttl:
             asyncio.create_task(self.clean_up_in_memory_cache())

+    async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
+        """
+        Add value to set
+        """
+        # get the value
+        init_value = self.get_cache(key=key) or set()
+        for val in value:
+            init_value.add(val)
+        self.set_cache(key, init_value, ttl=ttl)
+        return value
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -363,6 +377,7 @@ class RedisCache(BaseCache):
                     start_time=start_time,
                     end_time=end_time,
                     parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    call_type="async_set_cache",
                 )
             )
             # NON blocking - notify users Redis is throwing an exception

@@ -482,6 +497,80 @@ class RedisCache(BaseCache):
                 cache_value,
             )

+    async def async_set_cache_sadd(
+        self, key, value: List, ttl: Optional[float], **kwargs
+    ):
+        start_time = time.time()
+        try:
+            _redis_client = self.init_async_client()
+        except Exception as e:
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_failure_hook(
+                    service=ServiceTypes.REDIS,
+                    duration=_duration,
+                    error=e,
+                    start_time=start_time,
+                    end_time=end_time,
+                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    call_type="async_set_cache_sadd",
+                )
+            )
+            # NON blocking - notify users Redis is throwing an exception
+            verbose_logger.error(
+                "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
+                str(e),
+                value,
+            )
+            raise e
+
+        key = self.check_and_fix_namespace(key=key)
+        async with _redis_client as redis_client:
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.sadd(key, *value)
+                if ttl is not None:
+                    _td = timedelta(seconds=ttl)
+                    await redis_client.expire(key, _td)
+                print_verbose(
+                    f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
+                )
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_success_hook(
+                        service=ServiceTypes.REDIS,
+                        duration=_duration,
+                        call_type="async_set_cache_sadd",
+                        start_time=start_time,
+                        end_time=end_time,
+                        parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    )
+                )
+            except Exception as e:
+                end_time = time.time()
+                _duration = end_time - start_time
+                asyncio.create_task(
+                    self.service_logger_obj.async_service_failure_hook(
+                        service=ServiceTypes.REDIS,
+                        duration=_duration,
+                        error=e,
+                        call_type="async_set_cache_sadd",
+                        start_time=start_time,
+                        end_time=end_time,
+                        parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    )
+                )
+                # NON blocking - notify users Redis is throwing an exception
+                verbose_logger.error(
+                    "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
+                    str(e),
+                    value,
+                )
+
     async def batch_cache_write(self, key, value, **kwargs):
         print_verbose(
             f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
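For context, the Redis side of the new `async_set_cache_sadd` helper is the standard SADD plus EXPIRE pair: members go into a per-key set and the whole set is given a TTL. A rough sketch of that pattern with redis-py's asyncio client (connection details and key names are placeholders, assuming redis-py 5.x):

```python
# Minimal sketch of the SADD + EXPIRE pattern wrapped by the method above.
# Assumes a local Redis on localhost:6379; key and member names are illustrative only.
import asyncio
from datetime import timedelta

import redis.asyncio as redis


async def track_active_keys() -> None:
    client = redis.Redis(host="localhost", port=6379)
    key = "13-45:my-fake-model"  # "<HH-MM>:<model>" style key, one set per minute
    await client.sadd(key, "key-hash-1", "key-hash-2")  # set semantics: duplicates ignored
    await client.expire(key, timedelta(seconds=60))     # whole set expires after the minute
    print(await client.scard(key))                      # -> 2 active keys this minute
    await client.aclose()


asyncio.run(track_active_keys())
```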
@@ -1506,7 +1595,7 @@ class DualCache(BaseCache):
                     key, value, **kwargs
                 )

-            if self.redis_cache is not None and local_only == False:
+            if self.redis_cache is not None and local_only is False:
                 result = await self.redis_cache.async_increment(key, value, **kwargs)

             return result

@@ -1515,6 +1604,38 @@ class DualCache(BaseCache):
             verbose_logger.debug(traceback.format_exc())
             raise e

+    async def async_set_cache_sadd(
+        self, key, value: List, local_only: bool = False, **kwargs
+    ) -> None:
+        """
+        Add value to a set
+
+        Key - the key in cache
+
+        Value - str - the value you want to add to the set
+
+        Returns - None
+        """
+        try:
+            if self.in_memory_cache is not None:
+                _ = await self.in_memory_cache.async_set_cache_sadd(
+                    key, value, ttl=kwargs.get("ttl", None)
+                )
+
+            if self.redis_cache is not None and local_only is False:
+                _ = await self.redis_cache.async_set_cache_sadd(
+                    key, value, ttl=kwargs.get("ttl", None) ** kwargs
+                )
+
+            return None
+        except Exception as e:
+            verbose_logger.error(
+                "LiteLLM Cache: Excepton async set_cache_sadd: {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e
+
     def flush_cache(self):
         if self.in_memory_cache is not None:
             self.in_memory_cache.flush_cache()
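A rough usage sketch for the new DualCache helper, in-memory only (no Redis configured), using the method names shown in this diff; the key and member values are illustrative:

```python
# Rough usage sketch: add members to a per-minute set, then read the set size back.
import asyncio

from litellm.caching import DualCache


async def main() -> None:
    cache = DualCache()
    # add two members to the "<HH-MM>:<model>" set, expiring after 60s
    await cache.async_set_cache_sadd(key="13-45:my-fake-model", value=["proj-a", "proj-b"], ttl=60)
    members = await cache.async_get_cache(key="13-45:my-fake-model")
    print(len(members))  # -> 2, i.e. two "active" projects this minute


asyncio.run(main())
```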
@@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger):
         self,
         payload: ServiceLoggerPayload,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[datetime, float]] = None,
     ):
         from datetime import datetime

@@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger):
         self,
         payload: ServiceLoggerPayload,
         parent_otel_span: Optional[Span] = None,
-        start_time: Optional[datetime] = None,
-        end_time: Optional[datetime] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[float, datetime]] = None,
     ):
         from datetime import datetime

@@ -19,7 +19,8 @@ from litellm import (
     turn_off_message_logging,
     verbose_logger,
 )
-from litellm.caching import InMemoryCache, S3Cache
+from litellm.caching import InMemoryCache, S3Cache, DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,

@@ -1899,7 +1900,11 @@ def set_callbacks(callback_list, function_id=None):

 def _init_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
-) -> Callable:
+    internal_usage_cache: Optional[DualCache],
+    llm_router: Optional[
+        Any
+    ],  # expect litellm.Router, but typing errors due to circular import
+) -> CustomLogger:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
             if isinstance(callback, LagoLogger):
@@ -1935,3 +1940,58 @@
         _otel_logger = OpenTelemetry(config=otel_config)
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+
+        if internal_usage_cache is None:
+            raise Exception(
+                "Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
+                    internal_usage_cache
+                )
+            )
+
+        dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
+            internal_usage_cache=internal_usage_cache
+        )
+
+        if llm_router is not None and isinstance(llm_router, litellm.Router):
+            dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
+        _in_memory_loggers.append(dynamic_rate_limiter_obj)
+        return dynamic_rate_limiter_obj  # type: ignore
+
+
+def get_custom_logger_compatible_class(
+    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
+) -> Optional[CustomLogger]:
+    if logging_integration == "lago":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, LagoLogger):
+                return callback
+    elif logging_integration == "openmeter":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenMeterLogger):
+                return callback
+    elif logging_integration == "logfire":
+        if "LOGFIRE_TOKEN" not in os.environ:
+            raise ValueError("LOGFIRE_TOKEN not found in environment variables")
+        from litellm.integrations.opentelemetry import OpenTelemetry
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, OpenTelemetry):
+                return callback  # type: ignore
+
+    elif logging_integration == "dynamic_rate_limiter":
+        from litellm.proxy.hooks.dynamic_rate_limiter import (
+            _PROXY_DynamicRateLimitHandler,
+        )
+
+        for callback in _in_memory_loggers:
+            if isinstance(callback, _PROXY_DynamicRateLimitHandler):
+                return callback  # type: ignore
+    return None
@@ -428,7 +428,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
-    mock_response: Union[str, Exception] = "This is a mock request",
+    mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
     logging=None,
     custom_llm_provider=None,

@@ -477,6 +477,9 @@ def mock_completion(
         if time_delay is not None:
             time.sleep(time_delay)

+        if isinstance(mock_response, dict):
+            return ModelResponse(**mock_response)
+
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
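With this change a full response payload can be mocked by passing a dict, which is returned as `ModelResponse(**mock_response)`. A brief usage sketch (payload values are made up; the keys follow the OpenAI chat-completion shape):

```python
# Brief sketch of the new dict-style mock_response.
import litellm

fake_payload = {
    "id": "chatcmpl-mock-123",
    "created": 1699896916,
    "model": "gpt-3.5-turbo",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "hello-world"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response=fake_payload,  # returned as ModelResponse(**fake_payload)
)
print(resp.usage.total_tokens)  # -> 30
```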
@@ -1,61 +1,10 @@
-environment_variables:
-  LANGFUSE_PUBLIC_KEY: Q6K8MQN6L7sPYSJiFKM9eNrETOx6V/FxVPup4FqdKsZK1hyR4gyanlQ2KHLg5D5afng99uIt0JCEQ2jiKF9UxFvtnb4BbJ4qpeceH+iK8v/bdg==
-  LANGFUSE_SECRET_KEY: 5xQ7KMa6YMLsm+H/Pf1VmlqWq1NON5IoCxABhkUBeSck7ftsj2CmpkL2ZwrxwrktgiTUBH+3gJYBX+XBk7lqOOUpvmiLjol/E5lCqq0M1CqLWA==
-  SLACK_WEBHOOK_URL: RJjhS0Hhz0/s07sCIf1OTXmTGodpK9L2K9p953Z+fOX0l2SkPFT6mB9+yIrLufmlwEaku5NNEBKy//+AG01yOd+7wV1GhK65vfj3B/gTN8t5cuVnR4vFxKY5Rx4eSGLtzyAs+aIBTp4GoNXDIjroCqfCjPkItEZWCg==
-general_settings:
-  alerting:
-  - slack
-  alerting_threshold: 300
-  database_connection_pool_limit: 100
-  database_connection_timeout: 60
-  disable_master_key_return: true
-  health_check_interval: 300
-  proxy_batch_write_at: 60
-  ui_access_mode: all
-  # master_key: sk-1234
-litellm_settings:
-  allowed_fails: 3
-  failure_callback:
-  - prometheus
-  num_retries: 3
-  service_callback:
-  - prometheus_system
-  success_callback:
-  - langfuse
-  - prometheus
-  - langsmith
 model_list:
-- litellm_params:
-    model: gpt-3.5-turbo
-  model_name: gpt-3.5-turbo
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key
-    model: openai/my-fake-model
-    stream_timeout: 0.001
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key
-    model: openai/my-fake-model-2
-    stream_timeout: 0.001
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: os.environ/AZURE_API_BASE
-    api_key: os.environ/AZURE_API_KEY
-    api_version: 2023-07-01-preview
-    model: azure/chatgpt-v-2
-    stream_timeout: 0.001
-  model_name: azure-gpt-3.5
-- litellm_params:
-    api_key: os.environ/OPENAI_API_KEY
-    model: text-embedding-ada-002
-  model_name: text-embedding-ada-002
-- litellm_params:
-    model: text-completion-openai/gpt-3.5-turbo-instruct
-  model_name: gpt-instruct
-router_settings:
-  enable_pre_call_checks: true
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
+- model_name: my-fake-model
+  litellm_params:
+    model: gpt-3.5-turbo
+    api_key: my-fake-key
+    mock_response: hello-world
+  tpm: 60
+
+litellm_settings:
+  callbacks: ["dynamic_rate_limiter"]
@@ -30,6 +30,7 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_version: 2024-02-15-preview
       model: azure/chatgpt-v-2
+      tpm: 100
   model_name: gpt-3.5-turbo
 - litellm_params:
     model: anthropic.claude-3-sonnet-20240229-v1:0

@@ -40,6 +41,7 @@ model_list:
       api_version: 2024-02-15-preview
       model: azure/chatgpt-v-2
       drop_params: True
+      tpm: 100
   model_name: gpt-3.5-turbo
 - model_name: tts
   litellm_params:

@@ -67,8 +69,7 @@ model_list:
     max_input_tokens: 80920

 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+  callbacks: ["dynamic_rate_limiter"]
 # default_team_settings:
 #   - team_id: proj1
 #     success_callback: ["langfuse"]
litellm/proxy/hooks/dynamic_rate_limiter.py (new file, 205 lines)

@@ -0,0 +1,205 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute

import asyncio
import sys
import traceback
from datetime import datetime
from typing import List, Literal, Optional, Tuple, Union

from fastapi import HTTPException

import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.utils import get_utc_datetime


class DynamicRateLimiterCache:
    """
    Thin wrapper on DualCache for this file.

    Track number of active projects calling a model.
    """

    def __init__(self, cache: DualCache) -> None:
        self.cache = cache
        self.ttl = 60  # 1 min ttl

    async def async_get_cache(self, model: str) -> Optional[int]:
        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        key_name = "{}:{}".format(current_minute, model)
        _response = await self.cache.async_get_cache(key=key_name)
        response: Optional[int] = None
        if _response is not None:
            response = len(_response)
        return response

    async def async_set_cache_sadd(self, model: str, value: List):
        """
        Add value to set.

        Parameters:
        - model: str, the name of the model group
        - value: str, the team id

        Returns:
        - None

        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
        try:
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")

            key_name = "{}:{}".format(current_minute, model)
            await self.cache.async_set_cache_sadd(
                key=key_name, value=value, ttl=self.ttl
            )
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            raise e


class _PROXY_DynamicRateLimitHandler(CustomLogger):

    # Class variables or attributes
    def __init__(self, internal_usage_cache: DualCache):
        self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)

    def update_variables(self, llm_router: Router):
        self.llm_router = llm_router

    async def check_available_tpm(
        self, model: str
    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """
        For a given model, get its available tpm

        Returns
        - Tuple[available_tpm, model_tpm, active_projects]
          - available_tpm: int or null - always 0 or positive.
          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
          - active_projects: int or null
        """
        active_projects = await self.internal_usage_cache.async_get_cache(model=model)
        current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
            model_group=model
        )
        model_group_info: Optional[ModelGroupInfo] = (
            self.llm_router.get_model_group_info(model_group=model)
        )
        total_model_tpm: Optional[int] = None
        if model_group_info is not None and model_group_info.tpm is not None:
            total_model_tpm = model_group_info.tpm

        remaining_model_tpm: Optional[int] = None
        if total_model_tpm is not None and current_model_tpm is not None:
            remaining_model_tpm = total_model_tpm - current_model_tpm
        elif total_model_tpm is not None:
            remaining_model_tpm = total_model_tpm

        available_tpm: Optional[int] = None

        if remaining_model_tpm is not None:
            if active_projects is not None:
                available_tpm = int(remaining_model_tpm / active_projects)
            else:
                available_tpm = remaining_model_tpm

        if available_tpm is not None and available_tpm < 0:
            available_tpm = 0
        return available_tpm, remaining_model_tpm, active_projects

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
        """
        - For a model group
        - Check if tpm available
        - Raise RateLimitError if no tpm available
        """
        if "model" in data:
            available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                model=data["model"]
            )
            if available_tpm is not None and available_tpm == 0:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
                            user_api_key_dict.api_key,
                            available_tpm,
                            model_tpm,
                            active_projects,
                        )
                    },
                )
            elif available_tpm is not None:
                ## UPDATE CACHE WITH ACTIVE PROJECT
                asyncio.create_task(
                    self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                        model=data["model"],  # type: ignore
                        value=[user_api_key_dict.token or "default_key"],
                    )
                )
        return None

    async def async_post_call_success_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response
    ):
        try:
            if isinstance(response, ModelResponse):
                model_info = self.llm_router.get_model_info(
                    id=response._hidden_params["model_id"]
                )
                assert (
                    model_info is not None
                ), "Model info for model with id={} is None".format(
                    response._hidden_params["model_id"]
                )
                available_tpm, remaining_model_tpm, active_projects = (
                    await self.check_available_tpm(model=model_info["model_name"])
                )
                response._hidden_params["additional_headers"] = {
                    "x-litellm-model_group": model_info["model_name"],
                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
                    "x-ratelimit-remaining-model-tokens": remaining_model_tpm,
                    "x-ratelimit-current-active-projects": active_projects,
                }

                return response
            return await super().async_post_call_success_hook(
                user_api_key_dict, response
            )
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            return response
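The core of `check_available_tpm` is the split described in the docstring above: remaining model TPM divided evenly across active projects, floored at zero. A worked sketch of the same arithmetic in isolation (standalone, not the hook itself):

```python
# Standalone sketch of the check_available_tpm arithmetic, mirroring the method above.
from typing import Optional, Tuple


def split_tpm(
    total_model_tpm: Optional[int],
    current_model_tpm: Optional[int],
    active_projects: Optional[int],
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    remaining: Optional[int] = None
    if total_model_tpm is not None:
        remaining = total_model_tpm - (current_model_tpm or 0)

    available: Optional[int] = None
    if remaining is not None:
        # divide evenly across active projects; give everything to a lone caller
        available = int(remaining / active_projects) if active_projects else remaining
        available = max(available, 0)
    return available, remaining, active_projects


# 100 TPM model, 40 tokens already spent this minute, 3 active projects:
print(split_tpm(100, 40, 3))    # (20, 60, 3) -> each project may spend 20 tokens
# no usage yet, 2 active projects:
print(split_tpm(100, None, 2))  # (50, 100, 2)
```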
@@ -433,6 +433,7 @@ def get_custom_headers(
     version: Optional[str] = None,
     model_region: Optional[str] = None,
     fastest_response_batch_completion: Optional[bool] = None,
+    **kwargs,
 ) -> dict:
     exclude_values = {"", None}
     headers = {

@@ -448,6 +449,7 @@ def get_custom_headers(
             if fastest_response_batch_completion is not None
             else None
         ),
+        **{k: str(v) for k, v in kwargs.items()},
     }
     try:
         return {

@@ -2644,7 +2646,9 @@ async def startup_event():
         redis_cache=redis_usage_cache
     )  # used by parallel request limiter for rate limiting keys across instances

-    proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+    proxy_logging_obj._init_litellm_callbacks(
+        llm_router=llm_router
+    )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

     if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
         asyncio.create_task(
@@ -3061,6 +3065,14 @@ async def chat_completion(
                 headers=custom_headers,
             )

+        ### CALL HOOKS ### - modify outgoing data
+        response = await proxy_logging_obj.post_call_success_hook(
+            user_api_key_dict=user_api_key_dict, response=response
+        )
+
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
         fastapi_response.headers.update(
             get_custom_headers(
                 user_api_key_dict=user_api_key_dict,

@@ -3070,14 +3082,10 @@ async def chat_completion(
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                 fastest_response_batch_completion=fastest_response_batch_completion,
+                **additional_headers,
             )
         )

-        ### CALL HOOKS ### - modify outgoing data
-        response = await proxy_logging_obj.post_call_success_hook(
-            user_api_key_dict=user_api_key_dict, response=response
-        )
-
         return response
     except RejectedRequestError as e:
         _data = e.request_data

@@ -3116,11 +3124,10 @@ async def chat_completion(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
-                get_error_message_str(e=e)
+            "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
+                get_error_message_str(e=e), traceback.format_exc()
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
@@ -229,31 +229,32 @@ class ProxyLogging:
         if redis_cache is not None:
             self.internal_usage_cache.redis_cache = redis_cache

-    def _init_litellm_callbacks(self):
-        print_verbose("INITIALIZING LITELLM CALLBACKS!")
+    def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
         self.service_logging_obj = ServiceLogging()
-        litellm.callbacks.append(self.max_parallel_request_limiter)
-        litellm.callbacks.append(self.max_budget_limiter)
-        litellm.callbacks.append(self.cache_control_check)
-        litellm.callbacks.append(self.service_logging_obj)
+        litellm.callbacks.append(self.max_parallel_request_limiter)  # type: ignore
+        litellm.callbacks.append(self.max_budget_limiter)  # type: ignore
+        litellm.callbacks.append(self.cache_control_check)  # type: ignore
+        litellm.callbacks.append(self.service_logging_obj)  # type: ignore
         litellm.success_callback.append(
             self.slack_alerting_instance.response_taking_too_long_callback
         )
         for callback in litellm.callbacks:
             if isinstance(callback, str):
-                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                    callback
+                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                    callback,
+                    internal_usage_cache=self.internal_usage_cache,
+                    llm_router=llm_router,
                 )
             if callback not in litellm.input_callback:
-                litellm.input_callback.append(callback)
+                litellm.input_callback.append(callback)  # type: ignore
             if callback not in litellm.success_callback:
-                litellm.success_callback.append(callback)
+                litellm.success_callback.append(callback)  # type: ignore
             if callback not in litellm.failure_callback:
-                litellm.failure_callback.append(callback)
+                litellm.failure_callback.append(callback)  # type: ignore
             if callback not in litellm._async_success_callback:
-                litellm._async_success_callback.append(callback)
+                litellm._async_success_callback.append(callback)  # type: ignore
             if callback not in litellm._async_failure_callback:
-                litellm._async_failure_callback.append(callback)
+                litellm._async_failure_callback.append(callback)  # type: ignore

         if (
             len(litellm.input_callback) > 0
@@ -301,10 +302,19 @@ class ProxyLogging:

         try:
             for callback in litellm.callbacks:
-                if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
-                    callback.__class__
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if (
+                    _callback is not None
+                    and isinstance(_callback, CustomLogger)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
                 ):
-                    response = await callback.async_pre_call_hook(
+                    response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
                         data=data,

@@ -574,8 +584,15 @@ class ProxyLogging:

         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_failure_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_failure_hook(
                         user_api_key_dict=user_api_key_dict,
                         original_exception=original_exception,
                     )

@@ -596,8 +613,15 @@ class ProxyLogging:
         """
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_success_hook(
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    await _callback.async_post_call_success_hook(
                         user_api_key_dict=user_api_key_dict, response=response
                     )
             except Exception as e:
@@ -615,14 +639,25 @@ class ProxyLogging:
         Covers:
         1. /chat/completions
         """
-        for callback in litellm.callbacks:
-            try:
-                if isinstance(callback, CustomLogger):
-                    await callback.async_post_call_streaming_hook(
-                        user_api_key_dict=user_api_key_dict, response=response
-                    )
-            except Exception as e:
-                raise e
+        response_str: Optional[str] = None
+        if isinstance(response, ModelResponse):
+            response_str = litellm.get_response_string(response_obj=response)
+        if response_str is not None:
+            for callback in litellm.callbacks:
+                try:
+                    _callback: Optional[CustomLogger] = None
+                    if isinstance(callback, str):
+                        _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                            callback
+                        )
+                    else:
+                        _callback = callback  # type: ignore
+                    if _callback is not None and isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_streaming_hook(
+                            user_api_key_dict=user_api_key_dict, response=response_str
+                        )
+                except Exception as e:
+                    raise e
         return response

     async def post_call_streaming_hook(
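The pattern repeated across these proxy hooks is that a `litellm.callbacks` entry may now be a string (for example `"dynamic_rate_limiter"`) instead of an object, so each hook first resolves strings to the already-initialized `CustomLogger` singleton and skips anything that does not resolve. A condensed sketch of that resolution step, using the same names as the diff above:

```python
# Condensed sketch of the string -> CustomLogger resolution used by the hooks above.
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


def resolve_callback(callback) -> Optional[CustomLogger]:
    """Return a CustomLogger for a litellm.callbacks entry, or None if it cannot be resolved."""
    if isinstance(callback, str):
        # look up the already-initialized singleton (e.g. the dynamic rate limiter)
        return litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
            callback
        )
    if isinstance(callback, CustomLogger):
        return callback
    return None
```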
@@ -11,6 +11,7 @@ import asyncio
 import concurrent
 import copy
 import datetime as datetime_og
+import enum
 import hashlib
 import inspect
 import json

@@ -90,6 +91,10 @@ from litellm.utils import (
 )


+class RoutingArgs(enum.Enum):
+    ttl = 60  # 1min (RPM/TPM expire key)
+
+
 class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False

@@ -387,6 +392,11 @@ class Router:
             routing_strategy=routing_strategy,
             routing_strategy_args=routing_strategy_args,
         )
+        ## USAGE TRACKING ##
+        if isinstance(litellm._async_success_callback, list):
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
+        else:
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -2640,13 +2650,69 @@ class Router:
                 time.sleep(_timeout)

             if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                original_exception.max_retries = num_retries
-                original_exception.num_retries = current_attempt
+                setattr(original_exception, "max_retries", num_retries)
+                setattr(original_exception, "num_retries", current_attempt)

             raise original_exception

     ### HELPER FUNCTIONS

+    async def deployment_callback_on_success(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        """
+        Track remaining tpm/rpm quota for model in model_list
+        """
+        try:
+            """
+            Update TPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = completion_response["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                dt = get_utc_datetime()
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"global_router:{id}:tpm:{current_minute}"
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                )
+
+        except Exception as e:
+            verbose_router_logger.error(
+                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            pass
+
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
@@ -3812,10 +3878,39 @@ class Router:

         model_group_info: Optional[ModelGroupInfo] = None

+        total_tpm: Optional[int] = None
+        total_rpm: Optional[int] = None
+
         for model in self.model_list:
             if "model_name" in model and model["model_name"] == model_group:
                 # model in model group found #
                 litellm_params = LiteLLM_Params(**model["litellm_params"])
+                # get model tpm
+                _deployment_tpm: Optional[int] = None
+                if _deployment_tpm is None:
+                    _deployment_tpm = model.get("tpm", None)
+                if _deployment_tpm is None:
+                    _deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
+                if _deployment_tpm is None:
+                    _deployment_tpm = model.get("model_info", {}).get("tpm", None)
+
+                if _deployment_tpm is not None:
+                    if total_tpm is None:
+                        total_tpm = 0
+                    total_tpm += _deployment_tpm  # type: ignore
+                # get model rpm
+                _deployment_rpm: Optional[int] = None
+                if _deployment_rpm is None:
+                    _deployment_rpm = model.get("rpm", None)
+                if _deployment_rpm is None:
+                    _deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
+                if _deployment_rpm is None:
+                    _deployment_rpm = model.get("model_info", {}).get("rpm", None)
+
+                if _deployment_rpm is not None:
+                    if total_rpm is None:
+                        total_rpm = 0
+                    total_rpm += _deployment_rpm  # type: ignore
                 # get model info
                 try:
                     model_info = litellm.get_model_info(model=litellm_params.model)

@@ -3929,8 +4024,44 @@
                     "supported_openai_params"
                 ]

+        ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
+        if total_tpm is not None and model_group_info is not None:
+            model_group_info.tpm = total_tpm
+
+        if total_rpm is not None and model_group_info is not None:
+            model_group_info.rpm = total_rpm
+
         return model_group_info

+    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+        """
+        Returns remaining tpm quota for model group
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
+        tpm_keys: List[str] = []
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                tpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
+                )
+
+        ## TPM
+        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
+            keys=tpm_keys
+        )
+        tpm_usage: Optional[int] = None
+        if tpm_usage_list is not None:
+            for t in tpm_usage_list:
+                if isinstance(t, int):
+                    if tpm_usage is None:
+                        tpm_usage = 0
+                    tpm_usage += t
+
+        return tpm_usage
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.

@@ -4858,7 +4989,7 @@
     def reset(self):
         ## clean up on close
         litellm.success_callback = []
-        litellm.__async_success_callback = []
+        litellm._async_success_callback = []
         litellm.failure_callback = []
         litellm._async_failure_callback = []
         self.retry_policy = None
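Usage tracking in the router hinges on a per-deployment, per-minute cache key: `deployment_callback_on_success` increments it by the response's total tokens (with a 60s TTL), and `get_model_group_usage` batch-reads one such key per deployment in the group and sums the integers. A small sketch of the key scheme (the deployment id and timestamp are examples):

```python
# Small sketch of the per-minute TPM key scheme used above; id and time are examples.
from datetime import datetime, timezone


def tpm_key(deployment_id: str, now: datetime) -> str:
    current_minute = now.strftime("%H-%M")  # same shape as get_utc_datetime().strftime("%H-%M")
    return f"global_router:{deployment_id}:tpm:{current_minute}"


now = datetime(2024, 6, 22, 13, 45, tzinfo=timezone.utc)
print(tpm_key("abc123", now))  # -> global_router:abc123:tpm:13-45

# On success the router increments this key by usage.total_tokens (ttl 60s);
# get_model_group_usage reads one such key per deployment and sums the ints it finds.
```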
|
486
litellm/tests/test_dynamic_rate_limit_handler.py
Normal file
486
litellm/tests/test_dynamic_rate_limit_handler.py
Normal file
|
@ -0,0 +1,486 @@
|
# What is this?
## Unit tests for 'dynamic_rate_limiter.py`

import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from datetime import datetime
from typing import Optional, Tuple

from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest

import litellm
from litellm import DualCache, Router
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.dynamic_rate_limiter import (
    _PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
)

"""
Basic test cases:

- If 1 'active' project => give all tpm
- If 2 'active' projects => divide tpm in 2
"""


@pytest.fixture
def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
    internal_cache = DualCache()
    return DynamicRateLimitHandler(internal_usage_cache=internal_cache)


@pytest.fixture
def mock_response() -> litellm.ModelResponse:
    return litellm.ModelResponse(
        **{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1699896916,
            "model": "gpt-3.5-turbo-0125",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_abc123",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": '{\n"location": "Boston, MA"\n}',
                                },
                            }
                        ],
                    },
                    "logprobs": None,
                    "finish_reason": "tool_calls",
                }
            ],
            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
        }
    )


@pytest.fixture
def user_api_key_auth() -> UserAPIKeyAuth:
    return UserAPIKeyAuth()


@pytest.mark.parametrize("num_projects", [1, 2, 100])
@pytest.mark.asyncio
async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
    model = "my-fake-model"
    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4()) for _ in range(num_projects)]

    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    model_tpm = 100
    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    ## CHECK AVAILABLE TPM PER PROJECT

    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    expected_availability = int(model_tpm / num_projects)

    assert availability == expected_availability


@pytest.mark.asyncio
async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
    """
    Unit test. Tests if rate limit error raised when quota exhausted.
    """
    from fastapi import HTTPException

    model = "my-fake-model"
    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4())]

    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    model_tpm = 0
    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    ## CHECK AVAILABLE TPM PER PROJECT

    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    expected_availability = int(model_tpm / 1)

    assert availability == expected_availability

    ## CHECK if exception raised

    try:
        await dynamic_rate_limit_handler.async_pre_call_hook(
            user_api_key_dict=user_api_key_auth,
            cache=DualCache(),
            data={"model": model},
            call_type="completion",
        )
        pytest.fail("Expected this to raise HTTPexception")
    except HTTPException as e:
        assert e.status_code == 429  # check if rate limit error raised
        pass


@pytest.mark.asyncio
async def test_base_case(dynamic_rate_limit_handler, mock_response):
    """
    If just 1 active project

    it should get all the quota

    = allow request to go through
    - update token usage
    - exhaust all tpm with just 1 project
    - assert ratelimiterror raised at 100%+1 tpm
    """
    model = "my-fake-model"
    ## model tpm - 50
    model_tpm = 50
    ## tpm per request - 10
    setattr(
        mock_response,
        "usage",
        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
    )

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None
    allowed_fails = 1
    for _ in range(5):
        try:
            # check availability
            availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
                model=model
            )

            ## assert availability updated
            if prev_availability is not None and availability is not None:
                assert availability == prev_availability - 10

            print(
                "prev_availability={}, availability={}".format(
                    prev_availability, availability
                )
            )

            prev_availability = availability

            # make call
            await llm_router.acompletion(
                model=model, messages=[{"role": "user", "content": "hey!"}]
            )

            await asyncio.sleep(3)
        except Exception:
            if allowed_fails > 0:
                allowed_fails -= 1
            else:
                raise


@pytest.mark.asyncio
async def test_update_cache(
    dynamic_rate_limit_handler, mock_response, user_api_key_auth
):
    """
    Check if active project correctly updated
    """
    model = "my-fake-model"
    model_tpm = 50

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    ## INITIAL ACTIVE PROJECTS - ASSERT NONE
    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    assert active_projects is None

    ## MAKE CALL
    await dynamic_rate_limit_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_auth,
        cache=DualCache(),
        data={"model": model},
        call_type="completion",
    )

    await asyncio.sleep(2)
    ## INITIAL ACTIVE PROJECTS - ASSERT 1
    _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )

    assert active_projects == 1


@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects(
    dynamic_rate_limit_handler, mock_response, num_projects
):
    """
    If 2 active project

    it should split 50% each

    - assert available tpm is 0 after 50%+1 tpm calls
    """
    model = "my-fake-model"
    model_tpm = 50
    total_tokens_per_call = 10
    step_tokens_per_call_per_project = total_tokens_per_call / num_projects

    available_tpm_per_project = int(model_tpm / num_projects)

    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

    setattr(
        mock_response,
        "usage",
        litellm.Usage(
            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
        ),
    )

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None

    print("expected_runs: {}".format(expected_runs))
    for i in range(expected_runs + 1):
        # check availability
        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
            model=model
        )

        ## assert availability updated
        if prev_availability is not None and availability is not None:
            assert (
                availability == prev_availability - step_tokens_per_call_per_project
            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
                availability,
                prev_availability - 10,
                i,
                step_tokens_per_call_per_project,
                model_tpm,
            )

        print(
            "prev_availability={}, availability={}".format(
                prev_availability, availability
            )
        )

        prev_availability = availability

        # make call
        await llm_router.acompletion(
            model=model, messages=[{"role": "user", "content": "hey!"}]
        )

        await asyncio.sleep(3)

    # check availability
    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )
    assert availability == 0


@pytest.mark.parametrize("num_projects", [2])
@pytest.mark.asyncio
async def test_multiple_projects_e2e(
    dynamic_rate_limit_handler, mock_response, num_projects
):
    """
    2 parallel calls with different keys, same model

    If 2 active project

    it should split 50% each

    - assert available tpm is 0 after 50%+1 tpm calls
    """
    model = "my-fake-model"
    model_tpm = 50
    total_tokens_per_call = 10
    step_tokens_per_call_per_project = total_tokens_per_call / num_projects

    available_tpm_per_project = int(model_tpm / num_projects)

    ## SET CACHE W/ ACTIVE PROJECTS
    projects = [str(uuid.uuid4()) for _ in range(num_projects)]
    await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
        model=model, value=projects
    )

    expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)

    setattr(
        mock_response,
        "usage",
        litellm.Usage(
            prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
        ),
    )

    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )
    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)

    prev_availability: Optional[int] = None

    print("expected_runs: {}".format(expected_runs))
    for i in range(expected_runs + 1):
        # check availability
        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
            model=model
        )

        ## assert availability updated
        if prev_availability is not None and availability is not None:
            assert (
                availability == prev_availability - step_tokens_per_call_per_project
            ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
                availability,
                prev_availability - 10,
                i,
                step_tokens_per_call_per_project,
                model_tpm,
            )

        print(
            "prev_availability={}, availability={}".format(
                prev_availability, availability
            )
        )

        prev_availability = availability

        # make call
        await llm_router.acompletion(
            model=model, messages=[{"role": "user", "content": "hey!"}]
        )

        await asyncio.sleep(3)

    # check availability
    availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
        model=model
    )
    assert availability == 0

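Taken together, the new tests pin down one invariant: the dynamic rate limiter splits a deployment's TPM evenly across the projects that were active in the current minute. A minimal sketch of that allocation arithmetic, based only on the behaviour asserted above (the helper below is illustrative, not the handler's actual code):

```python
from typing import Optional


def tpm_per_active_project(model_tpm: Optional[int], active_projects: int) -> Optional[int]:
    """Split a model's TPM evenly across active projects (illustrative sketch)."""
    if model_tpm is None:
        return None
    # avoid division by zero before any project has been recorded as active
    return int(model_tpm / max(active_projects, 1))


# 100 TPM shared by 2 active projects -> 50 TPM each, matching test_available_tpm
assert tpm_per_active_project(100, 2) == 50
assert tpm_per_active_project(100, 1) == 100
```
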
@@ -1730,3 +1730,99 @@ async def test_router_text_completion_client():
        print(responses)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.fixture
def mock_response() -> litellm.ModelResponse:
    return litellm.ModelResponse(
        **{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1699896916,
            "model": "gpt-3.5-turbo-0125",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_abc123",
                                "type": "function",
                                "function": {
                                    "name": "get_current_weather",
                                    "arguments": '{\n"location": "Boston, MA"\n}',
                                },
                            }
                        ],
                    },
                    "logprobs": None,
                    "finish_reason": "tool_calls",
                }
            ],
            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
        }
    )


@pytest.mark.asyncio
async def test_router_model_usage(mock_response):
    """
    Test if tracking used model tpm works as expected
    """
    model = "my-fake-model"
    model_tpm = 100
    setattr(
        mock_response,
        "usage",
        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
    )

    print(f"mock_response: {mock_response}")
    model_tpm = 100
    llm_router = Router(
        model_list=[
            {
                "model_name": model,
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": "my-key",
                    "api_base": "my-base",
                    "tpm": model_tpm,
                    "mock_response": mock_response,
                },
            }
        ]
    )

    allowed_fails = 1  # allow for changing b/w minutes

    for _ in range(2):
        try:
            _ = await llm_router.acompletion(
                model=model, messages=[{"role": "user", "content": "Hey!"}]
            )
            await asyncio.sleep(3)

            initial_usage = await llm_router.get_model_group_usage(model_group=model)

            # completion call - 10 tokens
            _ = await llm_router.acompletion(
                model=model, messages=[{"role": "user", "content": "Hey!"}]
            )

            await asyncio.sleep(3)
            updated_usage = await llm_router.get_model_group_usage(model_group=model)

            assert updated_usage == initial_usage + 10  # type: ignore
            break
        except Exception as e:
            if allowed_fails > 0:
                print(
                    f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
                )
                allowed_fails -= 1
            else:
                print(f"allowed_fails: {allowed_fails}")
                raise e

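`test_router_model_usage` relies on `Router.get_model_group_usage` returning the cumulative TPM recorded for a model group (possibly `None` before any traffic). A hedged sketch of measuring the per-call increment the same way the test does; the helper name and the 3-second flush wait are assumptions mirroring the test, not a documented API pattern:

```python
import asyncio


async def tpm_delta_for_one_call(llm_router, model: str) -> int:
    """Measure how much tracked TPM grows after one completion (sketch)."""
    before = await llm_router.get_model_group_usage(model_group=model)
    await llm_router.acompletion(
        model=model, messages=[{"role": "user", "content": "Hey!"}]
    )
    await asyncio.sleep(3)  # give the async usage logger time to record the call
    after = await llm_router.get_model_group_usage(model_group=model)
    return (after or 0) - (before or 0)
```
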
@@ -442,6 +442,8 @@ class ModelGroupInfo(BaseModel):
            "chat", "embedding", "completion", "image_generation", "audio_transcription"
        ]
    ] = Field(default="chat")
    tpm: Optional[int] = None
    rpm: Optional[int] = None
    supports_parallel_function_calling: bool = Field(default=False)
    supports_vision: bool = Field(default=False)
    supports_function_calling: bool = Field(default=False)

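The two added lines in this hunk are `tpm` and `rpm`; everything else is unchanged context. A small Pydantic sketch of the resulting shape (a stand-in class, not litellm's real `ModelGroupInfo`):

```python
from typing import Optional

from pydantic import BaseModel, Field


class ModelGroupInfoSketch(BaseModel):
    """Stand-in showing the newly added per-model-group rate-limit fields."""

    tpm: Optional[int] = None  # tokens-per-minute configured for the model group
    rpm: Optional[int] = None  # requests-per-minute configured for the model group
    supports_vision: bool = Field(default=False)


info = ModelGroupInfoSketch(tpm=100, rpm=10)
print(info.tpm, info.rpm)  # 100 10
```
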
@@ -340,14 +340,15 @@ def function_setup(
    )
    try:
        global callback_list, add_breadcrumb, user_logger_fn, Logging

        function_id = kwargs["id"] if "id" in kwargs else None

        if len(litellm.callbacks) > 0:
            for callback in litellm.callbacks:
                # check if callback is a string - e.g. "lago", "openmeter"
                if isinstance(callback, str):
-                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
-                        callback
+                    callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                        callback, internal_usage_cache=None, llm_router=None
                    )
                    if any(
                        isinstance(cb, type(callback))

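Besides the extra `internal_usage_cache` / `llm_router` arguments, the surrounding loop still de-duplicates callbacks by class before registering them. A standalone sketch of that `any(isinstance(...))` guard in plain Python (generic placeholder types, not litellm's callback registry):

```python
class MyLogger:
    """Placeholder for a custom-logger-compatible callback class."""


registered = [MyLogger()]
candidate = MyLogger()

# mirrors the `if any(isinstance(cb, type(callback)) ...)` check in the hunk above:
# a second callback of the same class is not added again
if not any(isinstance(cb, type(candidate)) for cb in registered):
    registered.append(candidate)

print(len(registered))  # 1 - the duplicate type was skipped
```
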
@@ -3895,12 +3896,16 @@ def get_formatted_prompt(


def get_response_string(response_obj: ModelResponse) -> str:
-    _choices: List[Choices] = response_obj.choices  # type: ignore
+    _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices

    response_str = ""
    for choice in _choices:
-        if choice.message.content is not None:
-            response_str += choice.message.content
+        if isinstance(choice, Choices):
+            if choice.message.content is not None:
+                response_str += choice.message.content
+        elif isinstance(choice, StreamingChoices):
+            if choice.delta.content is not None:
+                response_str += choice.delta.content

    return response_str

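The updated helper now pulls text out of both regular and streaming choices. A self-contained sketch of the same branching with stand-in dataclasses (litellm's `Choices` / `StreamingChoices` are replaced by minimal look-alikes so the snippet runs on its own):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Message:
    content: Optional[str]


@dataclass
class Delta:
    content: Optional[str]


@dataclass
class Choice:  # stand-in for litellm's Choices
    message: Message


@dataclass
class StreamChoice:  # stand-in for litellm's StreamingChoices
    delta: Delta


def response_string(choices: List[Union[Choice, StreamChoice]]) -> str:
    """Mirror the branching added to get_response_string above (sketch)."""
    response_str = ""
    for choice in choices:
        if isinstance(choice, Choice):
            if choice.message.content is not None:
                response_str += choice.message.content
        elif isinstance(choice, StreamChoice):
            if choice.delta.content is not None:
                response_str += choice.delta.content
    return response_str


print(response_string([Choice(Message("Hello ")), StreamChoice(Delta("world"))]))
# -> "Hello world"
```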