LiteLLM Minor Fixes and Improvements (11/09/2024) (#5634)

* fix(caching.py): set ttl for async_increment cache

Fixes an issue where the TTL was not being set on the Redis client's increment_cache calls.

Fixes https://github.com/BerriAI/litellm/issues/5609

* fix(caching.py): fix increment cache with TTL for the sync increment cache on Redis

Fixes https://github.com/BerriAI/litellm/issues/5609
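
For context, a minimal sketch of the fix against a plain `redis-py` client (the helper name `increment_with_ttl` is illustrative, not the LiteLLM API): increment first, then attach a TTL only if the key has no expiry yet, so repeated increments never reset the window.

```python
import redis  # assumes a reachable Redis and the `redis` package installed


def increment_with_ttl(client: redis.Redis, key: str, value: int, ttl: int) -> int:
    """Increment `key`, then set an expiry only if the key has none yet."""
    result = client.incr(name=key, amount=value)
    if client.ttl(key) == -1:  # -1 => key exists but has no expiration
        client.expire(key, ttl)
    return result
```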

* fix(router.py): support adding retry policy + allowed fails policy via config.yaml
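
In practice this means the retry/allowed-fails policies coming from `config.yaml` (which arrive as plain dicts) are coerced into `RetryPolicy` / `AllowedFailsPolicy` objects inside the Router. A hedged sketch of the equivalent SDK call (the model list is illustrative):

```python
from litellm import Router

# illustrative single-deployment model list
model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
]

# plain dicts (as parsed from config.yaml) are now accepted alongside
# RetryPolicy / AllowedFailsPolicy objects
router = Router(
    model_list=model_list,
    retry_policy={"BadRequestErrorRetries": 3, "ContentPolicyViolationErrorRetries": 4},
    allowed_fails_policy={"RateLimitErrorAllowedFails": 100},
)
```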

* fix(router.py): don't cooldown single deployments

There's no point, as there is no other deployment to load balance with (see the sketch below).
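
A hedged sketch of the shape of the check (the wrapper name `should_cooldown` is hypothetical; `get_model_group` is the helper added in this PR):

```python
def should_cooldown(router, model_id: str) -> bool:
    """Skip cooldown when the deployment is the only member of its model group."""
    model_group = router.get_model_group(id=model_id)
    if model_group is not None and len(model_group) == 1:
        return False  # nothing else to route to, so keep the deployment active
    return True  # otherwise defer to the normal cooldown checks
```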

* fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens

Closes https://github.com/BerriAI/litellm/issues/5605
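
The domain check itself is simple; a minimal standalone sketch mirroring the `is_allowed_domain` logic added below:

```python
from typing import Optional


def is_allowed_domain(user_email: str, allowed_domain: Optional[str]) -> bool:
    """No configured domain means allow everyone; otherwise compare the part after '@'."""
    if allowed_domain is None:
        return True
    return user_email.split("@")[-1] == allowed_domain


assert is_allowed_domain("user@my-co.com", "my-co.com") is True
assert is_allowed_domain("user@other.com", "my-co.com") is False
```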

* docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs

* fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set

Fixes an issue where key-level logging settings were not applied when team metadata was not None.
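
In effect the lookup order is reversed: key-level `logging` metadata is checked before team metadata. A hedged sketch of the resolution order (the helper name is hypothetical):

```python
from typing import Optional


def resolve_logging_settings(key_metadata: Optional[dict], team_metadata: Optional[dict]):
    """Key-level 'logging' settings win, even when the key belongs to a team."""
    if key_metadata and "logging" in key_metadata:
        return key_metadata["logging"]  # dynamic callbacks set on the key
    if team_metadata and "callback_settings" in team_metadata:
        return team_metadata["callback_settings"]  # fall back to team-level settings
    return None
```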

* fix(secret_managers/main.py): load environment variables correctly

Fixes an issue where `os.environ/`-prefixed values were not being loaded correctly.
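
The replacement parser only treats a literal 'true'/'false' as a boolean, so arbitrary secret strings pass through untouched. A minimal sketch mirroring the `str_to_bool` helper added in this PR (sample values are illustrative):

```python
from typing import Optional


def str_to_bool(value: str) -> Optional[bool]:
    """Only 'true'/'false' (any casing) count as booleans; everything else is None."""
    return {"true": True, "false": False}.get(value.strip().lower())


assert str_to_bool("True") is True
assert str_to_bool("false") is False
assert str_to_bool("sk-lf-7cc8b620") is None  # an ordinary secret stays a plain string
```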

* test(test_router.py): fix test

* feat(spend_tracking_utils.py): support logging additional usage params, e.g. prompt caching values for DeepSeek
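
The idea, sketched with illustrative values (the cache-counter field names are examples, not a fixed schema): any usage key beyond the three standard token counts is copied into `additional_usage_values` on the spend log.

```python
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]

usage = {
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_tokens": 46,
    "prompt_cache_hit_tokens": 10,  # example provider-specific field
    "prompt_cache_miss_tokens": 2,
}

additional_usage_values = {k: v for k, v in usage.items() if k not in special_usage_fields}
# -> {'prompt_cache_hit_tokens': 10, 'prompt_cache_miss_tokens': 2}
```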

* test: fix tests

* test: fix test

* test: fix test

* test: fix test

* test: fix test
Krish Dholakia · 2024-09-11 22:36:06 -07:00 · committed by GitHub
parent 70100d716b · commit 98c34a7e27
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
25 changed files with 745 additions and 114 deletions


@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-  - id: mypy
-    name: mypy
-    entry: python3 -m mypy --ignore-missing-imports
-    language: system
-    types: [python]
-    files: ^litellm/
+  # - id: mypy
+  #   name: mypy
+  #   entry: python3 -m mypy --ignore-missing-imports
+  #   language: system
+  #   types: [python]
+  #   files: ^litellm/
   - id: isort
     name: isort
     entry: isort


@@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
 }'
 ```
 
+## Advanced - Upsert Users + Allowed Email Domains
+
+Allow users who belong to a specific email domain, automatic access to the proxy.
+
+```yaml
+general_settings:
+  master_key: sk-1234
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload
+    user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
+    user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
+```


@@ -1038,6 +1038,12 @@ print(f"response: {response}")
 - Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
 - Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
 
+[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436)
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 Example:
 
 ```python
@@ -1101,6 +1107,24 @@ response = await router.acompletion(
 )
 ```
 
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+```yaml
+router_settings:
+  retry_policy: {
+    "BadRequestErrorRetries": 3,
+    "ContentPolicyViolationErrorRetries": 4
+  }
+
+  allowed_fails_policy: {
+    "ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
+    "RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment
+  }
+```
+
+</TabItem>
+</Tabs>
+
 ### Fallbacks


@@ -304,40 +304,25 @@ class RedisCache(BaseCache):
                 f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
             )
 
-    def increment_cache(self, key, value: int, **kwargs) -> int:
+    def increment_cache(
+        self, key, value: int, ttl: Optional[float] = None, **kwargs
+    ) -> int:
         _redis_client = self.redis_client
         start_time = time.time()
         try:
             result = _redis_client.incr(name=key, amount=value)
-            ## LOGGING ##
-            end_time = time.time()
-            _duration = end_time - start_time
-            asyncio.create_task(
-                self.service_logger_obj.service_success_hook(
-                    service=ServiceTypes.REDIS,
-                    duration=_duration,
-                    call_type="increment_cache",
-                    start_time=start_time,
-                    end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
-                )
-            )
+
+            if ttl is not None:
+                # check if key already has ttl, if not -> set ttl
+                current_ttl = _redis_client.ttl(key)
+                if current_ttl == -1:
+                    # Key has no expiration
+                    _redis_client.expire(key, ttl)
             return result
         except Exception as e:
             ## LOGGING ##
             end_time = time.time()
             _duration = end_time - start_time
-            asyncio.create_task(
-                self.service_logger_obj.async_service_failure_hook(
-                    service=ServiceTypes.REDIS,
-                    duration=_duration,
-                    error=e,
-                    call_type="increment_cache",
-                    start_time=start_time,
-                    end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
-                )
-            )
             verbose_logger.error(
                 "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
@@ -606,12 +591,22 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()  # logging done in here
 
-    async def async_increment(self, key, value: float, **kwargs) -> float:
+    async def async_increment(
+        self, key, value: float, ttl: Optional[float] = None, **kwargs
+    ) -> float:
         _redis_client = self.init_async_client()
         start_time = time.time()
         try:
             async with _redis_client as redis_client:
                 result = await redis_client.incrbyfloat(name=key, amount=value)
+
+                if ttl is not None:
+                    # check if key already has ttl, if not -> set ttl
+                    current_ttl = await redis_client.ttl(key)
+                    if current_ttl == -1:
+                        # Key has no expiration
+                        await redis_client.expire(key, ttl)
+
                 ## LOGGING ##
                 end_time = time.time()
                 _duration = end_time - start_time


@@ -1609,15 +1609,24 @@ class Logging:
         """
         from litellm.types.router import RouterErrors
 
+        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
+        metadata = litellm_params.get("metadata") or {}
+
+        ## BASE CASE ## check if rate limit error for model group size 1
+        is_base_case = False
+        if metadata.get("model_group_size") is not None:
+            model_group_size = metadata.get("model_group_size")
+            if isinstance(model_group_size, int) and model_group_size == 1:
+                is_base_case = True
+
         ## check if special error ##
-        if RouterErrors.no_deployments_available.value not in str(exception):
+        if (
+            RouterErrors.no_deployments_available.value not in str(exception)
+            and is_base_case is False
+        ):
             return
 
         ## get original model group ##
-        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
-        metadata = litellm_params.get("metadata") or {}
+
         model_group = metadata.get("model_group") or None
         for callback in litellm._async_failure_callback:
             if isinstance(callback, CustomLogger):  # custom logger class

Three file diffs suppressed because one or more lines are too long.


@@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
     - team_allowed_routes: list of allowed routes for proxy team roles.
     - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
+    - user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
+    - user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy.
     - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
     - public_key_ttl: Default - 600s. TTL for caching public JWT keys.
@@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
     )
     org_id_jwt_field: Optional[str] = None
     user_id_jwt_field: Optional[str] = None
+    user_email_jwt_field: Optional[str] = None
+    user_allowed_email_domain: Optional[str] = None
     user_id_upsert: bool = Field(
         default=False, description="If user doesn't exist, upsert them into the db."
     )
@@ -1690,6 +1694,9 @@ class SpendLogsMetadata(TypedDict):
     Specific metadata k,v pairs logged to spendlogs for easier cost tracking
     """
 
+    additional_usage_values: Optional[
+        dict
+    ]  # covers provider-specific usage information - e.g. prompt caching
     user_api_key: Optional[str]
     user_api_key_alias: Optional[str]
     user_api_key_team_id: Optional[str]


@@ -78,6 +78,19 @@ class JWTHandler:
                 return False
         return True
 
+    def is_enforced_email_domain(self) -> bool:
+        """
+        Returns:
+        - True: if 'user_allowed_email_domain' is set
+        - False: if 'user_allowed_email_domain' is None
+        """
+
+        if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
+            self.litellm_jwtauth.user_allowed_email_domain, str
+        ):
+            return True
+        return False
+
     def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
         try:
             if self.litellm_jwtauth.team_id_jwt_field is not None:
@@ -90,12 +103,14 @@
             team_id = default_value
         return team_id
 
-    def is_upsert_user_id(self) -> bool:
+    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
         """
         Returns:
-        - True: if 'user_id_upsert' is set
+        - True: if 'user_id_upsert' is set AND valid_user_email is not False
         - False: if not
         """
+        if valid_user_email is False:
+            return False
         return self.litellm_jwtauth.user_id_upsert
 
     def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
@@ -103,11 +118,23 @@
             if self.litellm_jwtauth.user_id_jwt_field is not None:
                 user_id = token[self.litellm_jwtauth.user_id_jwt_field]
             else:
-                user_id = None
+                user_id = default_value
         except KeyError:
             user_id = default_value
         return user_id
 
+    def get_user_email(
+        self, token: dict, default_value: Optional[str]
+    ) -> Optional[str]:
+        try:
+            if self.litellm_jwtauth.user_email_jwt_field is not None:
+                user_email = token[self.litellm_jwtauth.user_email_jwt_field]
+            else:
+                user_email = None
+        except KeyError:
+            user_email = default_value
+        return user_email
+
     def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
         try:
             if self.litellm_jwtauth.org_id_jwt_field is not None:
@@ -183,6 +210,16 @@
         return public_key
 
+    def is_allowed_domain(self, user_email: str) -> bool:
+        if self.litellm_jwtauth.user_allowed_email_domain is None:
+            return True
+
+        email_domain = user_email.split("@")[-1]  # Extract domain from email
+        if email_domain == self.litellm_jwtauth.user_allowed_email_domain:
+            return True
+        else:
+            return False
+
     async def auth_jwt(self, token: str) -> dict:
         # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
         # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret


@@ -250,6 +250,7 @@ async def user_api_key_auth(
                     raise Exception(
                         f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
                     )
+
             # get team id
             team_id = jwt_handler.get_team_id(
                 token=jwt_valid_token, default_value=None
@@ -296,10 +297,30 @@
                     parent_otel_span=parent_otel_span,
                     proxy_logging_obj=proxy_logging_obj,
                 )
+
+            # [OPTIONAL] allowed user email domains
+            valid_user_email: Optional[bool] = None
+            user_email: Optional[str] = None
+            if jwt_handler.is_enforced_email_domain():
+                """
+                if 'allowed_email_subdomains' is set,
+
+                - checks if token contains 'email' field
+                - checks if 'email' is from an allowed domain
+                """
+                user_email = jwt_handler.get_user_email(
+                    token=jwt_valid_token, default_value=None
+                )
+                if user_email is None:
+                    valid_user_email = False
+                else:
+                    valid_user_email = jwt_handler.is_allowed_domain(
+                        user_email=user_email
+                    )
+
             # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
             user_object = None
             user_id = jwt_handler.get_user_id(
-                token=jwt_valid_token, default_value=None
+                token=jwt_valid_token, default_value=user_email
             )
+
             if user_id is not None:
                 # get the user object
@@ -307,11 +328,12 @@
                     user_id=user_id,
                     prisma_client=prisma_client,
                     user_api_key_cache=user_api_key_cache,
-                    user_id_upsert=jwt_handler.is_upsert_user_id(),
+                    user_id_upsert=jwt_handler.is_upsert_user_id(
+                        valid_user_email=valid_user_email
+                    ),
                     parent_otel_span=parent_otel_span,
                     proxy_logging_obj=proxy_logging_obj,
                 )
             # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
             end_user_object = None
             end_user_id = jwt_handler.get_end_user_id(
@@ -802,7 +824,7 @@
         # collect information for alerting #
         ####################################
 
-        user_email: Optional[str] = None
+        user_email = None
         # Check if the token has any user id information
         if user_obj is not None:
             user_email = user_obj.user_email


@@ -107,7 +107,16 @@
     user_api_key_dict: UserAPIKeyAuth,
 ) -> Optional[TeamCallbackMetadata]:
     callback_settings_obj: Optional[TeamCallbackMetadata] = None
-    if user_api_key_dict.team_metadata is not None:
+    if (
+        user_api_key_dict.metadata is not None
+        and "logging" in user_api_key_dict.metadata
+    ):
+        for item in user_api_key_dict.metadata["logging"]:
+            callback_settings_obj = convert_key_logging_metadata_to_callback(
+                data=AddTeamCallback(**item),
+                team_callback_settings_obj=callback_settings_obj,
+            )
+    elif user_api_key_dict.team_metadata is not None:
         team_metadata = user_api_key_dict.team_metadata
         if "callback_settings" in team_metadata:
             callback_settings = team_metadata.get("callback_settings", None) or {}
@@ -124,15 +133,7 @@
                 }
             }
         """
-    elif (
-        user_api_key_dict.metadata is not None
-        and "logging" in user_api_key_dict.metadata
-    ):
-        for item in user_api_key_dict.metadata["logging"]:
-            callback_settings_obj = convert_key_logging_metadata_to_callback(
-                data=AddTeamCallback(**item),
-                team_callback_settings_obj=callback_settings_obj,
-            )
+
     return callback_settings_obj


@@ -84,6 +84,7 @@ def get_logging_payload(
         user_api_key_team_alias=None,
         spend_logs_metadata=None,
         requester_ip_address=None,
+        additional_usage_values=None,
     )
     if isinstance(metadata, dict):
         verbose_proxy_logger.debug(
@@ -100,6 +101,13 @@
             }
         )
 
+        special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
+        additional_usage_values = {}
+        for k, v in usage.items():
+            if k not in special_usage_fields:
+                additional_usage_values.update({k: v})
+        clean_metadata["additional_usage_values"] = additional_usage_values
+
     if litellm.cache is not None:
         cache_key = litellm.cache.get_cache_key(**kwargs)
     else:


@@ -161,10 +161,10 @@ class Router:
         enable_tag_filtering: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
         retry_policy: Optional[
-            RetryPolicy
+            Union[RetryPolicy, dict]
         ] = None,  # set custom retries for different exceptions
-        model_group_retry_policy: Optional[
-            Dict[str, RetryPolicy]
+        model_group_retry_policy: Dict[
+            str, RetryPolicy
         ] = {},  # set custom retry policies based on model group
         allowed_fails: Optional[
             int
@@ -263,7 +263,7 @@
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
         self.enable_tag_filtering = enable_tag_filtering
-        if self.set_verbose == True:
+        if self.set_verbose is True:
             if debug_level == "INFO":
                 verbose_router_logger.setLevel(logging.INFO)
             elif debug_level == "DEBUG":
@@ -454,11 +454,35 @@
         )
         self.routing_strategy_args = routing_strategy_args
-        self.retry_policy: Optional[RetryPolicy] = retry_policy
+        self.retry_policy: Optional[RetryPolicy] = None
+        if retry_policy is not None:
+            if isinstance(retry_policy, dict):
+                self.retry_policy = RetryPolicy(**retry_policy)
+            elif isinstance(retry_policy, RetryPolicy):
+                self.retry_policy = retry_policy
+            verbose_router_logger.info(
+                "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
+                    self.retry_policy.model_dump(exclude_none=True)
+                )
+            )
+
         self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
             model_group_retry_policy
         )
-        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
+
+        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
+        if allowed_fails_policy is not None:
+            if isinstance(allowed_fails_policy, dict):
+                self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
+            elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
+                self.allowed_fails_policy = allowed_fails_policy
+
+            verbose_router_logger.info(
+                "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
+                    self.allowed_fails_policy.model_dump(exclude_none=True)
+                )
+            )
+
         self.alerting_config: Optional[AlertingConfig] = alerting_config
         if self.alerting_config is not None:
             self._initialize_alerting()
@@ -3003,6 +3027,13 @@
         model_group = kwargs.get("model")
         num_retries = kwargs.pop("num_retries")
 
+        ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
+        _metadata: dict = kwargs.get("metadata") or {}
+        if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
+            model_list = self.get_model_list(model_name=_metadata["model_group"])
+            if model_list is not None:
+                _metadata.update({"model_group_size": len(model_list)})
+
         verbose_router_logger.debug(
             f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
         )
@@ -3165,6 +3196,7 @@
         If it fails after num_retries, fall back to another model group
         """
         mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
+
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
@@ -3173,6 +3205,7 @@
         content_policy_fallbacks = kwargs.get(
             "content_policy_fallbacks", self.content_policy_fallbacks
         )
+
         try:
             if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                 raise Exception(
@@ -3324,6 +3357,9 @@
             f"Inside function with retries: args - {args}; kwargs - {kwargs}"
         )
         original_function = kwargs.pop("original_function")
+        mock_testing_rate_limit_error = kwargs.pop(
+            "mock_testing_rate_limit_error", None
+        )
         num_retries = kwargs.pop("num_retries")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.pop(
@@ -3332,9 +3368,22 @@
         content_policy_fallbacks = kwargs.pop(
             "content_policy_fallbacks", self.content_policy_fallbacks
         )
+        model_group = kwargs.get("model")
+
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
+            if (
+                mock_testing_rate_limit_error is not None
+                and mock_testing_rate_limit_error is True
+            ):
+                verbose_router_logger.info(
+                    "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
+                )
+                raise litellm.RateLimitError(
+                    model=model_group,
+                    llm_provider="",
+                    message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+                )
+
             response = original_function(*args, **kwargs)
             return response
         except Exception as e:
@@ -3571,17 +3620,26 @@
         )  # don't change existing ttl
 
     def _is_cooldown_required(
-        self, exception_status: Union[str, int], exception_str: Optional[str] = None
-    ):
+        self,
+        model_id: str,
+        exception_status: Union[str, int],
+        exception_str: Optional[str] = None,
+    ) -> bool:
         """
         A function to determine if a cooldown is required based on the exception status.
 
         Parameters:
+            model_id (str) The id of the model in the model list
             exception_status (Union[str, int]): The status of the exception.
 
         Returns:
             bool: True if a cooldown is required, False otherwise.
         """
+        ## BASE CASE - single deployment
+        model_group = self.get_model_group(id=model_id)
+        if model_group is not None and len(model_group) == 1:
+            return False
+
         try:
             ignored_strings = ["APIConnectionError"]
             if (
@@ -3677,7 +3735,9 @@
             if (
                 self._is_cooldown_required(
-                    exception_status=exception_status, exception_str=str(original_exception)
+                    model_id=deployment,
+                    exception_status=exception_status,
+                    exception_str=str(original_exception),
                 )
                 is False
             ):
@@ -3690,7 +3750,9 @@
                 exception=original_exception,
             )
 
-            allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
+            allowed_fails = (
+                _allowed_fails if _allowed_fails is not None else self.allowed_fails
+            )
 
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
@@ -4298,6 +4360,18 @@
                 return model
         return None
 
+    def get_model_group(self, id: str) -> Optional[List]:
+        """
+        Return list of all models in the same model group as that model id
+        """
+
+        model_info = self.get_model_info(id=id)
+        if model_info is None:
+            return None
+
+        model_name = model_info["model_name"]
+        return self.get_model_list(model_name=model_name)
+
     def _set_model_group_info(
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
@@ -5528,18 +5602,19 @@
         ContentPolicyViolationErrorRetries: Optional[int] = None
         """
         # if we can find the exception then in the retry policy -> return the number of retries
-        retry_policy = self.retry_policy
+        retry_policy: Optional[RetryPolicy] = self.retry_policy
         if (
             self.model_group_retry_policy is not None
             and model_group is not None
             and model_group in self.model_group_retry_policy
         ):
-            retry_policy = self.model_group_retry_policy.get(model_group, None)
+            retry_policy = self.model_group_retry_policy.get(model_group, None)  # type: ignore
+
         if retry_policy is None:
             return None
         if isinstance(retry_policy, dict):
             retry_policy = RetryPolicy(**retry_policy)
         if (
             isinstance(exception, litellm.BadRequestError)
             and retry_policy.BadRequestErrorRetries is not None


@@ -29,6 +29,27 @@ def _is_base64(s):
         return False
 
+def str_to_bool(value: str) -> Optional[bool]:
+    """
+    Converts a string to a boolean if it's a recognized boolean string.
+    Returns None if the string is not a recognized boolean value.
+
+    :param value: The string to be checked.
+    :return: True or False if the string is a recognized boolean, otherwise None.
+    """
+    true_values = {"true"}
+    false_values = {"false"}
+
+    value_lower = value.strip().lower()
+
+    if value_lower in true_values:
+        return True
+    elif value_lower in false_values:
+        return False
+    else:
+        return None
+
+
 def get_secret(
     secret_name: str,
     default_value: Optional[Union[str, bool]] = None,
@@ -257,17 +278,12 @@
                     return secret
         else:
             secret = os.environ.get(secret_name)
-            try:
-                secret_value_as_bool = (
-                    ast.literal_eval(secret) if secret is not None else None
-                )
-                if isinstance(secret_value_as_bool, bool):
-                    return secret_value_as_bool
-                else:
-                    return secret
-            except Exception:
-                if default_value is not None:
-                    return default_value
+            secret_value_as_bool = str_to_bool(secret) if secret is not None else None
+            if secret_value_as_bool is not None and isinstance(
+                secret_value_as_bool, bool
+            ):
+                return secret_value_as_bool
+            else:
                 return secret
     except Exception as e:
         if default_value is not None:


@@ -54,6 +54,7 @@ VERTEX_MODELS_TO_NOT_TEST = [
     "gemini-flash-experimental",
     "gemini-1.5-flash-exp-0827",
     "gemini-pro-flash",
+    "gemini-1.5-flash-exp-0827",
 ]


@@ -38,6 +38,8 @@ from litellm.integrations.custom_logger import CustomLogger
 ## 1. router.completion() + router.embeddings()
 ## 2. proxy.completions + proxy.embeddings
 
+litellm.num_retries = 0
+
 
 class CompletionCustomHandler(
     CustomLogger
@@ -401,7 +403,7 @@ async def test_async_chat_azure():
             "rpm": 1800,
         },
     ]
-    router = Router(model_list=model_list)  # type: ignore
+    router = Router(model_list=model_list, num_retries=0)  # type: ignore
     response = await router.acompletion(
         model="gpt-3.5-turbo",
         messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@@ -413,7 +415,7 @@
     )  # pre, post, success
     # streaming
     litellm.callbacks = [customHandler_streaming_azure_router]
-    router2 = Router(model_list=model_list)  # type: ignore
+    router2 = Router(model_list=model_list, num_retries=0)  # type: ignore
     response = await router2.acompletion(
         model="gpt-3.5-turbo",
         messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@@ -443,7 +445,7 @@
         },
     ]
     litellm.callbacks = [customHandler_failure]
-    router3 = Router(model_list=model_list)  # type: ignore
+    router3 = Router(model_list=model_list, num_retries=0)  # type: ignore
     try:
         response = await router3.acompletion(
             model="gpt-3.5-turbo",
@@ -505,7 +507,7 @@ async def test_async_embedding_azure():
         },
     ]
     litellm.callbacks = [customHandler_failure]
-    router3 = Router(model_list=model_list)  # type: ignore
+    router3 = Router(model_list=model_list, num_retries=0)  # type: ignore
     try:
         response = await router3.aembedding(
             model="azure-embedding-model", input=["hello from litellm!"]
@@ -678,22 +680,21 @@ async def test_rate_limit_error_callback():
         pass
 
     with patch.object(
-        customHandler, "log_model_group_rate_limit_error", new=MagicMock()
+        customHandler, "log_model_group_rate_limit_error", new=AsyncMock()
     ) as mock_client:
         print(
             f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
         )
-        for _ in range(3):
-            try:
-                _ = await router.acompletion(
-                    model="my-test-gpt",
-                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
-                    litellm_logging_obj=litellm_logging_obj,
-                )
-            except (litellm.RateLimitError, ValueError):
-                pass
+        try:
+            _ = await router.acompletion(
+                model="my-test-gpt",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                litellm_logging_obj=litellm_logging_obj,
+            )
+        except (litellm.RateLimitError, ValueError):
+            pass
 
         await asyncio.sleep(3)
         mock_client.assert_called_once()


@@ -23,8 +23,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from fastapi import Request
 
+import litellm
 from litellm.caching import DualCache
-from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
+from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.management_endpoints.team_endpoints import new_team
 from litellm.proxy.proxy_server import chat_completion
@@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience):
         raise e
 
-    from unittest.mock import AsyncMock
-
     import pytest
@@ -844,3 +843,148 @@ async def test_team_cache_update_called():
     await asyncio.sleep(3)
     mock_call_cache.assert_awaited_once()
@pytest.fixture
def public_jwt_key():
import json
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
# Get private key in PEM format
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
# Get public key in PEM format
public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_key_obj = serialization.load_pem_public_key(
public_key, backend=default_backend()
)
# Convert RSA public key object to JWK (JSON Web Key)
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
return {"private_key": private_key, "public_jwk": public_jwk}
def mock_user_object(*args, **kwargs):
print("Args: {}".format(args))
print("kwargs: {}".format(kwargs))
assert kwargs["user_id_upsert"] is True
@pytest.mark.parametrize(
"user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
"""
Allow anyone with an `@xyz.com` email make a request to the proxy.
Relevant issue: https://github.com/BerriAI/litellm/issues/5605
"""
import jwt
from starlette.datastructures import URL
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
from litellm.proxy.proxy_server import user_api_key_auth
public_jwk = public_jwt_key["public_jwk"]
private_key = public_jwt_key["private_key"]
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
jwt_handler = JWTHandler()
jwt_handler.user_api_key_cache = cache
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
user_email_jwt_field="email",
user_allowed_email_domain="berri.ai",
user_id_upsert=True,
)
# VALID TOKEN
## GENERATE A TOKEN
# Assuming the current time is in UTC
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
team_id = f"team123_{uuid.uuid4()}"
payload = {
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": "litellm-proxy",
"email": user_email,
}
# Generate the JWT token
# But before, you should convert bytes to string
private_key_str = private_key.decode("utf-8")
## team token
token = jwt.encode(payload, private_key_str, algorithm="RS256")
## VERIFY IT WORKS
# Expect the call to succeed
response = await jwt_handler.auth_jwt(token=token)
assert response is not None # Adjust this based on your actual response check
## RUN IT THROUGH USER API KEY AUTH
bearer_token = "Bearer " + token
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in
setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
setattr(litellm.proxy.proxy_server, "prisma_client", {})
# AsyncMock(
# return_value=LiteLLM_UserTable(
# spend=0, user_id=user_email, max_budget=None, user_email=user_email
# )
# ),
with patch.object(
litellm.proxy.auth.user_api_key_auth,
"get_user_object",
side_effect=mock_user_object,
) as mock_client:
if should_work:
# Expect the call to succeed
result = await user_api_key_auth(request=request, api_key=bearer_token)
assert result is not None # Adjust this based on your actual response check
else:
# Expect the call to fail
with pytest.raises(
Exception
): # Replace with the actual exception raised on failure
resp = await user_api_key_auth(request=request, api_key=bearer_token)
print(resp)


@@ -1,16 +1,17 @@
-import sys
-import os
 import io
+import os
+import sys
 
 sys.path.insert(0, os.path.abspath("../.."))
 
-from litellm import completion
 import litellm
+from litellm import completion
 
 litellm.failure_callback = ["lunary"]
 litellm.success_callback = ["lunary"]
 litellm.set_verbose = True
 
+
 def test_lunary_logging():
     try:
         response = completion(
@@ -24,6 +25,7 @@ def test_lunary_logging():
     except Exception as e:
         print(e)
 
+
 test_lunary_logging()
@@ -37,8 +39,6 @@ def test_lunary_template():
     except Exception as e:
         print(e)
 
-test_lunary_template()
-
 
 def test_lunary_logging_with_metadata():
     try:
@@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata():
             metadata={
                 "run_name": "litellmRUN",
                 "project_name": "litellm-completion",
-                "tags": ["tag1", "tag2"]
+                "tags": ["tag1", "tag2"],
             },
         )
         print(response)
     except Exception as e:
         print(e)
 
-test_lunary_logging_with_metadata()
-
 
 def test_lunary_with_tools():
     import litellm
 
-    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
+        }
+    ]
     tools = [
         {
             "type": "function",
@@ -90,13 +94,11 @@ def test_lunary_with_tools():
         tools=tools,
         tool_choice="auto",  # auto is default, but we'll be explicit
     )
 
     response_message = response.choices[0].message
     print("\nLLM Response:\n", response.choices[0].message)
 
-test_lunary_with_tools()
-
 
 def test_lunary_logging_with_streaming_and_metadata():
     try:
         response = completion(
@@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata():
                 continue
     except Exception as e:
         print(e)
-
-test_lunary_logging_with_streaming_and_metadata()


@@ -11,8 +11,11 @@ import litellm
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+from litellm.proxy.litellm_pre_call_utils import (
+    _get_dynamic_logging_metadata,
+    add_litellm_data_to_request,
+)
 from litellm.types.utils import SupportedCacheControls
@@ -204,3 +207,87 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request(
     # assert (
     #     new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"]
     # )
@pytest.mark.parametrize(
"callback_vars",
[
{
"langfuse_host": "https://us.cloud.langfuse.com",
"langfuse_public_key": "pk-lf-9636b7a6-c066",
"langfuse_secret_key": "sk-lf-7cc8b620",
},
{
"langfuse_host": "os.environ/LANGFUSE_HOST_TEMP",
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP",
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP",
},
],
)
def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
user_api_key_dict = UserAPIKeyAuth(
token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
key_name="sk-...63Fg",
key_alias=None,
spend=0.000111,
max_budget=None,
expires=None,
models=[],
aliases={},
config={},
user_id=None,
team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
max_parallel_requests=None,
metadata={
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": callback_vars,
}
]
},
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
budget_reset_at=None,
allowed_cache_controls=[],
permissions={},
model_spend={},
model_max_budget={},
soft_budget_cooldown=False,
litellm_budget_table=None,
org_id=None,
team_spend=0.000132,
team_alias=None,
team_tpm_limit=None,
team_rpm_limit=None,
team_max_budget=None,
team_models=[],
team_blocked=False,
soft_budget=None,
team_model_aliases=None,
team_member_spend=None,
team_member=None,
team_metadata={},
end_user_id=None,
end_user_tpm_limit=None,
end_user_rpm_limit=None,
end_user_max_budget=None,
last_refreshed_at=1726101560.967527,
api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1",
user_role=LitellmUserRoles.INTERNAL_USER,
allowed_model_region=None,
parent_otel_span=None,
rpm_limit_per_model=None,
tpm_limit_per_model=None,
)
callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
assert callbacks is not None
for var in callbacks.callback_vars.values():
assert "os.environ" not in var


@@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error():
     except litellm.APIConnectionError as e:
         assert (
             Router()._is_cooldown_required(
-                exception_status=e.code, exception_str=str(e)
+                model_id="", exception_status=e.code, exception_str=str(e)
             )
             is False
         )
@@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
                 "litellm_params": {
                     "model": "openai/text-embedding-ada-002",
                 },
-            }
+            },
+            {
+                "model_name": "text-embedding-ada-002",
+                "litellm_params": {
+                    "model": "openai/text-embedding-ada-002",
+                },
+            },
         ]
     )


@@ -21,6 +21,7 @@ import openai
 import litellm
 from litellm import Router
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict
 
 
 @pytest.mark.asyncio
@@ -112,3 +113,40 @@ async def test_dynamic_cooldowns():
     assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
     assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
@pytest.mark.parametrize("num_deployments", [1, 2])
def test_single_deployment_no_cooldowns(num_deployments):
"""
Do not cooldown on single deployment.
Cooldown on multiple deployments.
"""
model_list = []
for i in range(num_deployments):
model = DeploymentTypedDict(
model_name="gpt-3.5-turbo",
litellm_params=LiteLLMParamsTypedDict(
model="gpt-3.5-turbo",
),
)
model_list.append(model)
router = Router(model_list=model_list, allowed_fails=0, num_retries=0)
with patch.object(
router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
if num_deployments == 1:
mock_client.assert_not_called()
else:
mock_client.assert_called_once()


@@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type):
             "tpm": 240000,
             "rpm": 1800,
         },
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": _api_key,
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
     ]
 
     router = Router(model_list=model_list, allowed_fails=3)


@@ -17,6 +17,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
 import pytest
 
 import litellm
@@ -459,3 +461,138 @@ async def test_router_completion_streaming():
     - Unit test for sync 'pre_call_checks'
     - Unit test for async 'async_pre_call_checks'
     """
@pytest.mark.asyncio
async def test_router_caching_ttl():
"""
Confirm caching ttl's work as expected.
Relevant issue: https://github.com/BerriAI/litellm/issues/5609
"""
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
"mock_response": "Hello world",
},
"model_info": {"id": 1},
}
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=os.getenv("REDIS_PORT"),
)
assert router.cache.redis_cache is not None
increment_cache_kwargs = {}
with patch.object(
router.cache.redis_cache,
"async_increment",
new=AsyncMock(),
) as mock_client:
await router.acompletion(model=model, messages=messages)
mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")
increment_cache_kwargs = {
"key": mock_client.call_args.args[0],
"value": mock_client.call_args.args[1],
"ttl": mock_client.call_args.kwargs["ttl"],
}
assert mock_client.call_args.kwargs["ttl"] == 60
## call redis async increment and check if ttl correctly set
await router.cache.redis_cache.async_increment(**increment_cache_kwargs)
_redis_client = router.cache.redis_cache.init_async_client()
async with _redis_client as redis_client:
current_ttl = await redis_client.ttl(increment_cache_kwargs["key"])
assert current_ttl >= 0
print(f"current_ttl: {current_ttl}")
def test_router_caching_ttl_sync():
"""
Confirm caching ttl's work as expected.
Relevant issue: https://github.com/BerriAI/litellm/issues/5609
"""
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
"mock_response": "Hello world",
},
"model_info": {"id": 1},
}
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=os.getenv("REDIS_PORT"),
)
assert router.cache.redis_cache is not None
increment_cache_kwargs = {}
with patch.object(
router.cache.redis_cache,
"increment_cache",
new=MagicMock(),
) as mock_client:
router.completion(model=model, messages=messages)
print(mock_client.call_args_list)
mock_client.assert_called()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")
increment_cache_kwargs = {
"key": mock_client.call_args.args[0],
"value": mock_client.call_args.args[1],
"ttl": mock_client.call_args.kwargs["ttl"],
}
assert mock_client.call_args.kwargs["ttl"] == 60
## call redis async increment and check if ttl correctly set
router.cache.redis_cache.increment_cache(**increment_cache_kwargs)
_redis_client = router.cache.redis_cache.redis_client
current_ttl = _redis_client.ttl(increment_cache_kwargs["key"])
assert current_ttl >= 0
print(f"current_ttl: {current_ttl}")


@@ -4,7 +4,6 @@ import time
 
 import pytest
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-from traceloop.sdk import Traceloop
 
 import litellm
@@ -13,6 +12,8 @@ sys.path.insert(0, os.path.abspath("../.."))
 
 @pytest.fixture()
 def exporter():
+    from traceloop.sdk import Traceloop
+
     exporter = InMemorySpanExporter()
     Traceloop.init(
         app_name="test_litellm",