diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a33473b72..d429bc6b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
repos:
- repo: local
hooks:
- - id: mypy
- name: mypy
- entry: python3 -m mypy --ignore-missing-imports
- language: system
- types: [python]
- files: ^litellm/
+ # - id: mypy
+ # name: mypy
+ # entry: python3 -m mypy --ignore-missing-imports
+ # language: system
+ # types: [python]
+ # files: ^litellm/
- id: isort
name: isort
entry: isort
diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md
index 659cc6edf..049ea0f98 100644
--- a/docs/my-website/docs/proxy/token_auth.md
+++ b/docs/my-website/docs/proxy/token_auth.md
@@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
}'
```
+
+## Advanced - Upsert Users + Allowed Email Domains
+
+Give users who belong to a specific email domain automatic access to the proxy.
+
+```yaml
+general_settings:
+ master_key: sk-1234
+ enable_jwt_auth: True
+ litellm_jwtauth:
+ user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload
+ user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
+ user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
+```
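+
+For example, with the config above a JWT whose payload carries an email on the allowed domain is granted access (and upserted into the db if new). The payload values below are illustrative:
+
+```json
+{
+  "sub": "user123",
+  "email": "user@my-co.com",
+  "exp": 1726000000
+}
+```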
\ No newline at end of file
diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md
index c7c6c3c97..87925516a 100644
--- a/docs/my-website/docs/routing.md
+++ b/docs/my-website/docs/routing.md
@@ -1038,6 +1038,12 @@ print(f"response: {response}")
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
+[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436)
+
+
+
+
+
Example:
```python
@@ -1101,6 +1107,24 @@ response = await router.acompletion(
)
```
+
+
+
+```yaml
+router_settings:
+ retry_policy: {
+ "BadRequestErrorRetries": 3,
+ "ContentPolicyViolationErrorRetries": 4
+ }
+ allowed_fails_policy: {
+ "ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
+ "RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment
+ }
+```
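+
+The same policies can also be passed as plain dicts when creating the router in the SDK. A minimal sketch - the `model_list` entry below is illustrative:
+
+```python
+from litellm import Router
+
+model_list = [
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {"model": "gpt-3.5-turbo"},
+    }
+]
+
+router = Router(
+    model_list=model_list,
+    retry_policy={
+        "BadRequestErrorRetries": 3,
+        "ContentPolicyViolationErrorRetries": 4,
+    },
+    allowed_fails_policy={
+        "ContentPolicyViolationErrorAllowedFails": 1000,
+        "RateLimitErrorAllowedFails": 100,
+    },
+)
+```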
+
+
+
+
### Fallbacks
diff --git a/litellm/caching.py b/litellm/caching.py
index 5add0cd8e..0a9fef417 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -304,40 +304,25 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
- def increment_cache(self, key, value: int, **kwargs) -> int:
+ def increment_cache(
+ self, key, value: int, ttl: Optional[float] = None, **kwargs
+ ) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
- ## LOGGING ##
- end_time = time.time()
- _duration = end_time - start_time
- asyncio.create_task(
- self.service_logger_obj.service_success_hook(
- service=ServiceTypes.REDIS,
- duration=_duration,
- call_type="increment_cache",
- start_time=start_time,
- end_time=end_time,
- parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
- )
- )
+
+ if ttl is not None:
+ # check if key already has ttl, if not -> set ttl
+ current_ttl = _redis_client.ttl(key)
+ if current_ttl == -1:
+ # Key has no expiration
+ _redis_client.expire(key, ttl)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
- asyncio.create_task(
- self.service_logger_obj.async_service_failure_hook(
- service=ServiceTypes.REDIS,
- duration=_duration,
- error=e,
- call_type="increment_cache",
- start_time=start_time,
- end_time=end_time,
- parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
- )
- )
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
@@ -606,12 +591,22 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() # logging done in here
- async def async_increment(self, key, value: float, **kwargs) -> float:
+ async def async_increment(
+ self, key, value: float, ttl: Optional[float] = None, **kwargs
+ ) -> float:
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
result = await redis_client.incrbyfloat(name=key, amount=value)
+
+ if ttl is not None:
+ # check if key already has ttl, if not -> set ttl
+ current_ttl = await redis_client.ttl(key)
+ if current_ttl == -1:
+ # Key has no expiration
+ await redis_client.expire(key, ttl)
+
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 0d3e59db5..9528b6fbb 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -1609,15 +1609,24 @@ class Logging:
"""
from litellm.types.router import RouterErrors
+ litellm_params: dict = self.model_call_details.get("litellm_params") or {}
+ metadata = litellm_params.get("metadata") or {}
+
+ ## BASE CASE ## check if rate limit error for model group size 1
+ is_base_case = False
+ if metadata.get("model_group_size") is not None:
+ model_group_size = metadata.get("model_group_size")
+ if isinstance(model_group_size, int) and model_group_size == 1:
+ is_base_case = True
## check if special error ##
- if RouterErrors.no_deployments_available.value not in str(exception):
+ if (
+ RouterErrors.no_deployments_available.value not in str(exception)
+ and is_base_case is False
+ ):
return
## get original model group ##
- litellm_params: dict = self.model_call_details.get("litellm_params") or {}
- metadata = litellm_params.get("metadata") or {}
-
model_group = metadata.get("model_group") or None
for callback in litellm._async_failure_callback:
if isinstance(callback, CustomLogger): # custom logger class
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 34d1b613d..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 07e68f30e..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index abb658918..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index da98b6a10..662d6d835 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
- team_allowed_routes: list of allowed routes for proxy team roles.
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
+ - user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
+ - user_allowed_email_domain: If specified, only user emails from the specified domain are allowed to access the proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
@@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
)
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
+ user_email_jwt_field: Optional[str] = None
+ user_allowed_email_domain: Optional[str] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
@@ -1690,6 +1694,9 @@ class SpendLogsMetadata(TypedDict):
Specific metadata k,v pairs logged to spendlogs for easier cost tracking
"""
+ additional_usage_values: Optional[
+ dict
+ ] # covers provider-specific usage information - e.g. prompt caching
user_api_key: Optional[str]
user_api_key_alias: Optional[str]
user_api_key_team_id: Optional[str]
diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py
index f8618781f..b39064ae6 100644
--- a/litellm/proxy/auth/handle_jwt.py
+++ b/litellm/proxy/auth/handle_jwt.py
@@ -78,6 +78,19 @@ class JWTHandler:
return False
return True
+ def is_enforced_email_domain(self) -> bool:
+ """
+ Returns:
+ - True: if 'user_allowed_email_domain' is set
+ - False: if 'user_allowed_email_domain' is None
+ """
+
+ if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
+ self.litellm_jwtauth.user_allowed_email_domain, str
+ ):
+ return True
+ return False
+
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.team_id_jwt_field is not None:
@@ -90,12 +103,14 @@ class JWTHandler:
team_id = default_value
return team_id
- def is_upsert_user_id(self) -> bool:
+ def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
"""
Returns:
- - True: if 'user_id_upsert' is set
+ - True: if 'user_id_upsert' is set AND valid_user_email is not False
- False: if not
"""
+ if valid_user_email is False:
+ return False
return self.litellm_jwtauth.user_id_upsert
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
@@ -103,11 +118,23 @@ class JWTHandler:
if self.litellm_jwtauth.user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.user_id_jwt_field]
else:
- user_id = None
+ user_id = default_value
except KeyError:
user_id = default_value
return user_id
+ def get_user_email(
+ self, token: dict, default_value: Optional[str]
+ ) -> Optional[str]:
+ try:
+ if self.litellm_jwtauth.user_email_jwt_field is not None:
+ user_email = token[self.litellm_jwtauth.user_email_jwt_field]
+ else:
+ user_email = None
+ except KeyError:
+ user_email = default_value
+ return user_email
+
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.org_id_jwt_field is not None:
@@ -183,6 +210,16 @@ class JWTHandler:
return public_key
+ def is_allowed_domain(self, user_email: str) -> bool:
+ if self.litellm_jwtauth.user_allowed_email_domain is None:
+ return True
+
+ email_domain = user_email.split("@")[-1] # Extract domain from email
+ if email_domain == self.litellm_jwtauth.user_allowed_email_domain:
+ return True
+ else:
+ return False
+
async def auth_jwt(self, token: str) -> dict:
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index deee81ffd..114f27d44 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -250,6 +250,7 @@ async def user_api_key_auth(
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
+
# get team id
team_id = jwt_handler.get_team_id(
token=jwt_valid_token, default_value=None
@@ -296,10 +297,30 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
+ # [OPTIONAL] allowed user email domains
+ valid_user_email: Optional[bool] = None
+ user_email: Optional[str] = None
+ if jwt_handler.is_enforced_email_domain():
+ """
+                    if 'user_allowed_email_domain' is set,
+
+ - checks if token contains 'email' field
+ - checks if 'email' is from an allowed domain
+ """
+ user_email = jwt_handler.get_user_email(
+ token=jwt_valid_token, default_value=None
+ )
+ if user_email is None:
+ valid_user_email = False
+ else:
+ valid_user_email = jwt_handler.is_allowed_domain(
+ user_email=user_email
+ )
+
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(
- token=jwt_valid_token, default_value=None
+ token=jwt_valid_token, default_value=user_email
)
if user_id is not None:
# get the user object
@@ -307,11 +328,12 @@ async def user_api_key_auth(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
- user_id_upsert=jwt_handler.is_upsert_user_id(),
+ user_id_upsert=jwt_handler.is_upsert_user_id(
+ valid_user_email=valid_user_email
+ ),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
-
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
@@ -802,7 +824,7 @@ async def user_api_key_auth(
# collect information for alerting #
####################################
- user_email: Optional[str] = None
+ user_email = None
# Check if the token has any user id information
if user_obj is not None:
user_email = user_obj.user_email
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 890c576c9..4c6172a4d 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -107,7 +107,16 @@ def _get_dynamic_logging_metadata(
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[TeamCallbackMetadata]:
callback_settings_obj: Optional[TeamCallbackMetadata] = None
- if user_api_key_dict.team_metadata is not None:
+ if (
+ user_api_key_dict.metadata is not None
+ and "logging" in user_api_key_dict.metadata
+ ):
+ for item in user_api_key_dict.metadata["logging"]:
+ callback_settings_obj = convert_key_logging_metadata_to_callback(
+ data=AddTeamCallback(**item),
+ team_callback_settings_obj=callback_settings_obj,
+ )
+ elif user_api_key_dict.team_metadata is not None:
team_metadata = user_api_key_dict.team_metadata
if "callback_settings" in team_metadata:
callback_settings = team_metadata.get("callback_settings", None) or {}
@@ -124,15 +133,7 @@ def _get_dynamic_logging_metadata(
}
}
"""
- elif (
- user_api_key_dict.metadata is not None
- and "logging" in user_api_key_dict.metadata
- ):
- for item in user_api_key_dict.metadata["logging"]:
- callback_settings_obj = convert_key_logging_metadata_to_callback(
- data=AddTeamCallback(**item),
- team_callback_settings_obj=callback_settings_obj,
- )
+
return callback_settings_obj
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index a1a0b9733..bdeef92cc 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -84,6 +84,7 @@ def get_logging_payload(
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
+ additional_usage_values=None,
)
if isinstance(metadata, dict):
verbose_proxy_logger.debug(
@@ -100,6 +101,13 @@ def get_logging_payload(
}
)
+ special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
+ additional_usage_values = {}
+ for k, v in usage.items():
+ if k not in special_usage_fields:
+ additional_usage_values.update({k: v})
+ clean_metadata["additional_usage_values"] = additional_usage_values
+
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
diff --git a/litellm/router.py b/litellm/router.py
index 5a01f4f39..c187474f1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -161,10 +161,10 @@ class Router:
enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
retry_policy: Optional[
- RetryPolicy
+ Union[RetryPolicy, dict]
] = None, # set custom retries for different exceptions
- model_group_retry_policy: Optional[
- Dict[str, RetryPolicy]
+ model_group_retry_policy: Dict[
+ str, RetryPolicy
] = {}, # set custom retry policies based on model group
allowed_fails: Optional[
int
@@ -263,7 +263,7 @@ class Router:
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
self.enable_tag_filtering = enable_tag_filtering
- if self.set_verbose == True:
+ if self.set_verbose is True:
if debug_level == "INFO":
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
@@ -454,11 +454,35 @@ class Router:
)
self.routing_strategy_args = routing_strategy_args
- self.retry_policy: Optional[RetryPolicy] = retry_policy
+ self.retry_policy: Optional[RetryPolicy] = None
+ if retry_policy is not None:
+ if isinstance(retry_policy, dict):
+ self.retry_policy = RetryPolicy(**retry_policy)
+ elif isinstance(retry_policy, RetryPolicy):
+ self.retry_policy = retry_policy
+ verbose_router_logger.info(
+ "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
+ self.retry_policy.model_dump(exclude_none=True)
+ )
+ )
+
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
- self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
+
+ self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
+ if allowed_fails_policy is not None:
+ if isinstance(allowed_fails_policy, dict):
+ self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
+ elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
+ self.allowed_fails_policy = allowed_fails_policy
+
+ verbose_router_logger.info(
+ "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
+ self.allowed_fails_policy.model_dump(exclude_none=True)
+ )
+ )
+
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
@@ -3003,6 +3027,13 @@ class Router:
model_group = kwargs.get("model")
num_retries = kwargs.pop("num_retries")
+ ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
+ _metadata: dict = kwargs.get("metadata") or {}
+ if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
+ model_list = self.get_model_list(model_name=_metadata["model_group"])
+ if model_list is not None:
+ _metadata.update({"model_group_size": len(model_list)})
+
verbose_router_logger.debug(
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
)
@@ -3165,6 +3196,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
+
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
@@ -3173,6 +3205,7 @@ class Router:
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
+
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@@ -3324,6 +3357,9 @@ class Router:
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
+ mock_testing_rate_limit_error = kwargs.pop(
+ "mock_testing_rate_limit_error", None
+ )
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop(
@@ -3332,9 +3368,22 @@ class Router:
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
+ model_group = kwargs.get("model")
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
+ if (
+ mock_testing_rate_limit_error is not None
+ and mock_testing_rate_limit_error is True
+ ):
+ verbose_router_logger.info(
+ "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
+ )
+ raise litellm.RateLimitError(
+ model=model_group,
+ llm_provider="",
+ message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
+ )
response = original_function(*args, **kwargs)
return response
except Exception as e:
@@ -3571,17 +3620,26 @@ class Router:
) # don't change existing ttl
def _is_cooldown_required(
- self, exception_status: Union[str, int], exception_str: Optional[str] = None
- ):
+ self,
+ model_id: str,
+ exception_status: Union[str, int],
+ exception_str: Optional[str] = None,
+ ) -> bool:
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
+            model_id (str): The id of the model in the model list
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
+ ## BASE CASE - single deployment
+ model_group = self.get_model_group(id=model_id)
+ if model_group is not None and len(model_group) == 1:
+ return False
+
try:
ignored_strings = ["APIConnectionError"]
if (
@@ -3677,7 +3735,9 @@ class Router:
if (
self._is_cooldown_required(
- exception_status=exception_status, exception_str=str(original_exception)
+ model_id=deployment,
+ exception_status=exception_status,
+ exception_str=str(original_exception),
)
is False
):
@@ -3690,7 +3750,9 @@ class Router:
exception=original_exception,
)
- allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
+ allowed_fails = (
+ _allowed_fails if _allowed_fails is not None else self.allowed_fails
+ )
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
@@ -4298,6 +4360,18 @@ class Router:
return model
return None
+ def get_model_group(self, id: str) -> Optional[List]:
+ """
+ Return list of all models in the same model group as that model id
+ """
+
+ model_info = self.get_model_info(id=id)
+ if model_info is None:
+ return None
+
+ model_name = model_info["model_name"]
+ return self.get_model_list(model_name=model_name)
+
def _set_model_group_info(
self, model_group: str, user_facing_model_group_name: str
) -> Optional[ModelGroupInfo]:
@@ -5528,18 +5602,19 @@ class Router:
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
- retry_policy = self.retry_policy
+ retry_policy: Optional[RetryPolicy] = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
- retry_policy = self.model_group_retry_policy.get(model_group, None)
+ retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
+
if (
isinstance(exception, litellm.BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py
index e136654c1..5d1f72cf7 100644
--- a/litellm/secret_managers/main.py
+++ b/litellm/secret_managers/main.py
@@ -29,6 +29,27 @@ def _is_base64(s):
return False
+def str_to_bool(value: str) -> Optional[bool]:
+ """
+ Converts a string to a boolean if it's a recognized boolean string.
+ Returns None if the string is not a recognized boolean value.
+
+ :param value: The string to be checked.
+ :return: True or False if the string is a recognized boolean, otherwise None.
+ """
+ true_values = {"true"}
+ false_values = {"false"}
+
+ value_lower = value.strip().lower()
+
+ if value_lower in true_values:
+ return True
+ elif value_lower in false_values:
+ return False
+ else:
+ return None
+
+
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
@@ -257,17 +278,12 @@ def get_secret(
return secret
else:
secret = os.environ.get(secret_name)
- try:
- secret_value_as_bool = (
- ast.literal_eval(secret) if secret is not None else None
- )
- if isinstance(secret_value_as_bool, bool):
- return secret_value_as_bool
- else:
- return secret
- except Exception:
- if default_value is not None:
- return default_value
+ secret_value_as_bool = str_to_bool(secret) if secret is not None else None
+ if secret_value_as_bool is not None and isinstance(
+ secret_value_as_bool, bool
+ ):
+ return secret_value_as_bool
+ else:
return secret
except Exception as e:
if default_value is not None:
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index de2729db7..91ed7ea4a 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -54,6 +54,7 @@ VERTEX_MODELS_TO_NOT_TEST = [
"gemini-flash-experimental",
"gemini-1.5-flash-exp-0827",
"gemini-pro-flash",
+ "gemini-1.5-flash-exp-0827",
]
diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py
index 80fc096e7..6ffa97d89 100644
--- a/litellm/tests/test_custom_callback_router.py
+++ b/litellm/tests/test_custom_callback_router.py
@@ -38,6 +38,8 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
+litellm.num_retries = 0
+
class CompletionCustomHandler(
CustomLogger
@@ -401,7 +403,7 @@ async def test_async_chat_azure():
"rpm": 1800,
},
]
- router = Router(model_list=model_list) # type: ignore
+ router = Router(model_list=model_list, num_retries=0) # type: ignore
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@@ -413,7 +415,7 @@ async def test_async_chat_azure():
) # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router]
- router2 = Router(model_list=model_list) # type: ignore
+ router2 = Router(model_list=model_list, num_retries=0) # type: ignore
response = await router2.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@@ -443,7 +445,7 @@ async def test_async_chat_azure():
},
]
litellm.callbacks = [customHandler_failure]
- router3 = Router(model_list=model_list) # type: ignore
+ router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:
response = await router3.acompletion(
model="gpt-3.5-turbo",
@@ -505,7 +507,7 @@ async def test_async_embedding_azure():
},
]
litellm.callbacks = [customHandler_failure]
- router3 = Router(model_list=model_list) # type: ignore
+ router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:
response = await router3.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
@@ -678,22 +680,21 @@ async def test_rate_limit_error_callback():
pass
with patch.object(
- customHandler, "log_model_group_rate_limit_error", new=MagicMock()
+ customHandler, "log_model_group_rate_limit_error", new=AsyncMock()
) as mock_client:
print(
f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
)
- for _ in range(3):
- try:
- _ = await router.acompletion(
- model="my-test-gpt",
- messages=[{"role": "user", "content": "Hey, how's it going?"}],
- litellm_logging_obj=litellm_logging_obj,
- )
- except (litellm.RateLimitError, ValueError):
- pass
+ try:
+ _ = await router.acompletion(
+ model="my-test-gpt",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ litellm_logging_obj=litellm_logging_obj,
+ )
+ except (litellm.RateLimitError, ValueError):
+ pass
await asyncio.sleep(3)
mock_client.assert_called_once()
diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py
index ddafdb933..51bf55c9c 100644
--- a/litellm/tests/test_jwt.py
+++ b/litellm/tests/test_jwt.py
@@ -23,8 +23,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
+import litellm
from litellm.caching import DualCache
-from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
+from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
@@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience):
raise e
-from unittest.mock import AsyncMock
-
import pytest
@@ -844,3 +843,148 @@ async def test_team_cache_update_called():
await asyncio.sleep(3)
mock_call_cache.assert_awaited_once()
+
+
+@pytest.fixture
+def public_jwt_key():
+ import json
+
+ import jwt
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives import serialization
+ from cryptography.hazmat.primitives.asymmetric import rsa
+
+ # Generate a private / public key pair using RSA algorithm
+ key = rsa.generate_private_key(
+ public_exponent=65537, key_size=2048, backend=default_backend()
+ )
+ # Get private key in PEM format
+ private_key = key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+
+ # Get public key in PEM format
+ public_key = key.public_key().public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+
+ public_key_obj = serialization.load_pem_public_key(
+ public_key, backend=default_backend()
+ )
+
+ # Convert RSA public key object to JWK (JSON Web Key)
+ public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
+
+ return {"private_key": private_key, "public_jwk": public_jwk}
+
+
+def mock_user_object(*args, **kwargs):
+ print("Args: {}".format(args))
+ print("kwargs: {}".format(kwargs))
+ assert kwargs["user_id_upsert"] is True
+
+
+@pytest.mark.parametrize(
+ "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
+)
+@pytest.mark.asyncio
+async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
+ """
+    Allow anyone with an email on the allowed domain (`berri.ai` in this test) to make a request to the proxy.
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/5605
+ """
+ import jwt
+ from starlette.datastructures import URL
+
+ from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
+ from litellm.proxy.proxy_server import user_api_key_auth
+
+ public_jwk = public_jwt_key["public_jwk"]
+ private_key = public_jwt_key["private_key"]
+
+ # set cache
+ cache = DualCache()
+
+ await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
+
+ jwt_handler = JWTHandler()
+
+ jwt_handler.user_api_key_cache = cache
+
+ jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
+ user_email_jwt_field="email",
+ user_allowed_email_domain="berri.ai",
+ user_id_upsert=True,
+ )
+
+ # VALID TOKEN
+ ## GENERATE A TOKEN
+ # Assuming the current time is in UTC
+ expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
+
+ team_id = f"team123_{uuid.uuid4()}"
+ payload = {
+ "sub": "user123",
+ "exp": expiration_time, # set the token to expire in 10 minutes
+ "scope": "litellm_team",
+ "client_id": team_id,
+ "aud": "litellm-proxy",
+ "email": user_email,
+ }
+
+ # Generate the JWT token
+ # But before, you should convert bytes to string
+ private_key_str = private_key.decode("utf-8")
+
+ ## team token
+ token = jwt.encode(payload, private_key_str, algorithm="RS256")
+
+ ## VERIFY IT WORKS
+ # Expect the call to succeed
+ response = await jwt_handler.auth_jwt(token=token)
+ assert response is not None # Adjust this based on your actual response check
+
+ ## RUN IT THROUGH USER API KEY AUTH
+ bearer_token = "Bearer " + token
+
+ request = Request(scope={"type": "http"})
+
+ request._url = URL(url="/chat/completions")
+
+    ## SET UP PROXY SERVER STATE
+    # then auth in with the generated token
+ setattr(
+ litellm.proxy.proxy_server,
+ "general_settings",
+ {
+ "enable_jwt_auth": True,
+ },
+ )
+ setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
+ setattr(litellm.proxy.proxy_server, "prisma_client", {})
+
+ with patch.object(
+ litellm.proxy.auth.user_api_key_auth,
+ "get_user_object",
+ side_effect=mock_user_object,
+ ) as mock_client:
+ if should_work:
+ # Expect the call to succeed
+ result = await user_api_key_auth(request=request, api_key=bearer_token)
+ assert result is not None # Adjust this based on your actual response check
+ else:
+ # Expect the call to fail
+ with pytest.raises(
+ Exception
+ ): # Replace with the actual exception raised on failure
+ resp = await user_api_key_auth(request=request, api_key=bearer_token)
+ print(resp)
diff --git a/litellm/tests/test_lunary.py b/litellm/tests/test_lunary.py
index cd068d990..d181d24c7 100644
--- a/litellm/tests/test_lunary.py
+++ b/litellm/tests/test_lunary.py
@@ -1,16 +1,17 @@
-import sys
-import os
import io
+import os
+import sys
sys.path.insert(0, os.path.abspath("../.."))
-from litellm import completion
import litellm
+from litellm import completion
litellm.failure_callback = ["lunary"]
litellm.success_callback = ["lunary"]
litellm.set_verbose = True
+
def test_lunary_logging():
try:
response = completion(
@@ -24,6 +25,7 @@ def test_lunary_logging():
except Exception as e:
print(e)
+
test_lunary_logging()
@@ -37,8 +39,6 @@ def test_lunary_template():
except Exception as e:
print(e)
-test_lunary_template()
-
def test_lunary_logging_with_metadata():
try:
@@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata():
metadata={
"run_name": "litellmRUN",
"project_name": "litellm-completion",
- "tags": ["tag1", "tag2"]
+ "tags": ["tag1", "tag2"],
},
)
print(response)
except Exception as e:
print(e)
-test_lunary_logging_with_metadata()
def test_lunary_with_tools():
import litellm
- messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
+ messages = [
+ {
+ "role": "user",
+ "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
+ }
+ ]
tools = [
{
"type": "function",
@@ -90,13 +94,11 @@ def test_lunary_with_tools():
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)
-
+
response_message = response.choices[0].message
print("\nLLM Response:\n", response.choices[0].message)
-test_lunary_with_tools()
-
def test_lunary_logging_with_streaming_and_metadata():
try:
response = completion(
@@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata():
continue
except Exception as e:
print(e)
-
-test_lunary_logging_with_streaming_and_metadata()
diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py
index 63361b09a..b5aac09d1 100644
--- a/litellm/tests/test_proxy_utils.py
+++ b/litellm/tests/test_proxy_utils.py
@@ -11,8 +11,11 @@ import litellm
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+from litellm.proxy.litellm_pre_call_utils import (
+ _get_dynamic_logging_metadata,
+ add_litellm_data_to_request,
+)
from litellm.types.utils import SupportedCacheControls
@@ -204,3 +207,87 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request(
# assert (
# new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"]
# )
+
+
+@pytest.mark.parametrize(
+ "callback_vars",
+ [
+ {
+ "langfuse_host": "https://us.cloud.langfuse.com",
+ "langfuse_public_key": "pk-lf-9636b7a6-c066",
+ "langfuse_secret_key": "sk-lf-7cc8b620",
+ },
+ {
+ "langfuse_host": "os.environ/LANGFUSE_HOST_TEMP",
+ "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP",
+ "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP",
+ },
+ ],
+)
+def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
+ os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
+ os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
+ os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
+ user_api_key_dict = UserAPIKeyAuth(
+ token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
+ key_name="sk-...63Fg",
+ key_alias=None,
+ spend=0.000111,
+ max_budget=None,
+ expires=None,
+ models=[],
+ aliases={},
+ config={},
+ user_id=None,
+ team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
+ max_parallel_requests=None,
+ metadata={
+ "logging": [
+ {
+ "callback_name": "langfuse",
+ "callback_type": "success",
+ "callback_vars": callback_vars,
+ }
+ ]
+ },
+ tpm_limit=None,
+ rpm_limit=None,
+ budget_duration=None,
+ budget_reset_at=None,
+ allowed_cache_controls=[],
+ permissions={},
+ model_spend={},
+ model_max_budget={},
+ soft_budget_cooldown=False,
+ litellm_budget_table=None,
+ org_id=None,
+ team_spend=0.000132,
+ team_alias=None,
+ team_tpm_limit=None,
+ team_rpm_limit=None,
+ team_max_budget=None,
+ team_models=[],
+ team_blocked=False,
+ soft_budget=None,
+ team_model_aliases=None,
+ team_member_spend=None,
+ team_member=None,
+ team_metadata={},
+ end_user_id=None,
+ end_user_tpm_limit=None,
+ end_user_rpm_limit=None,
+ end_user_max_budget=None,
+ last_refreshed_at=1726101560.967527,
+ api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1",
+ user_role=LitellmUserRoles.INTERNAL_USER,
+ allowed_model_region=None,
+ parent_otel_span=None,
+ rpm_limit_per_model=None,
+ tpm_limit_per_model=None,
+ )
+ callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
+
+ assert callbacks is not None
+
+ for var in callbacks.callback_vars.values():
+ assert "os.environ" not in var
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index fd89130fe..05d9f9f76 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error():
except litellm.APIConnectionError as e:
assert (
Router()._is_cooldown_required(
- exception_status=e.code, exception_str=str(e)
+ model_id="", exception_status=e.code, exception_str=str(e)
)
is False
)
@@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
"litellm_params": {
"model": "openai/text-embedding-ada-002",
},
- }
+ },
+ {
+ "model_name": "text-embedding-ada-002",
+ "litellm_params": {
+ "model": "openai/text-embedding-ada-002",
+ },
+ },
]
)
diff --git a/litellm/tests/test_router_cooldowns.py b/litellm/tests/test_router_cooldowns.py
index 3eef6e542..ac92dfbf0 100644
--- a/litellm/tests/test_router_cooldowns.py
+++ b/litellm/tests/test_router_cooldowns.py
@@ -21,6 +21,7 @@ import openai
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict
@pytest.mark.asyncio
@@ -112,3 +113,40 @@ async def test_dynamic_cooldowns():
assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
+
+
+@pytest.mark.parametrize("num_deployments", [1, 2])
+def test_single_deployment_no_cooldowns(num_deployments):
+ """
+ Do not cooldown on single deployment.
+
+ Cooldown on multiple deployments.
+ """
+ model_list = []
+ for i in range(num_deployments):
+ model = DeploymentTypedDict(
+ model_name="gpt-3.5-turbo",
+ litellm_params=LiteLLMParamsTypedDict(
+ model="gpt-3.5-turbo",
+ ),
+ )
+ model_list.append(model)
+
+ router = Router(model_list=model_list, allowed_fails=0, num_retries=0)
+
+ with patch.object(
+ router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
+ ) as mock_client:
+ try:
+ router.completion(
+ model="gpt-3.5-turbo",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ mock_response="litellm.RateLimitError",
+ )
+ except litellm.RateLimitError:
+ pass
+
+ if num_deployments == 1:
+ mock_client.assert_not_called()
+ else:
+ mock_client.assert_called_once()
diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py
index f0503cd3f..f4574212d 100644
--- a/litellm/tests/test_router_retries.py
+++ b/litellm/tests/test_router_retries.py
@@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type):
"tpm": 240000,
"rpm": 1800,
},
+ {
+ "model_name": "azure/gpt-3.5-turbo", # openai model name
+ "litellm_params": { # params for litellm completion/embedding call
+ "model": "azure/chatgpt-functioncalling",
+ "api_key": _api_key,
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ },
+ "tpm": 240000,
+ "rpm": 1800,
+ },
]
router = Router(model_list=model_list, allowed_fails=3)
diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py
index 1f3de0910..259bd0ee0 100644
--- a/litellm/tests/test_tpm_rpm_routing_v2.py
+++ b/litellm/tests/test_tpm_rpm_routing_v2.py
@@ -17,6 +17,8 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
+
import pytest
import litellm
@@ -459,3 +461,138 @@ async def test_router_completion_streaming():
- Unit test for sync 'pre_call_checks'
- Unit test for async 'async_pre_call_checks'
"""
+
+
+@pytest.mark.asyncio
+async def test_router_caching_ttl():
+ """
+    Confirm caching TTLs work as expected.
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/5609
+ """
+ messages = [
+ {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+ ]
+ model = "azure-model"
+ model_list = [
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "azure/gpt-turbo",
+ "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+ "api_base": "https://openai-france-1234.openai.azure.com",
+ "tpm": 1440,
+ "mock_response": "Hello world",
+ },
+ "model_info": {"id": 1},
+ }
+ ]
+ router = Router(
+ model_list=model_list,
+ routing_strategy="usage-based-routing-v2",
+ set_verbose=False,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_password=os.getenv("REDIS_PASSWORD"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+
+ assert router.cache.redis_cache is not None
+
+ increment_cache_kwargs = {}
+ with patch.object(
+ router.cache.redis_cache,
+ "async_increment",
+ new=AsyncMock(),
+ ) as mock_client:
+ await router.acompletion(model=model, messages=messages)
+
+ mock_client.assert_called_once()
+ print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+ print(f"mock_client.call_args.args: {mock_client.call_args.args}")
+
+ increment_cache_kwargs = {
+ "key": mock_client.call_args.args[0],
+ "value": mock_client.call_args.args[1],
+ "ttl": mock_client.call_args.kwargs["ttl"],
+ }
+
+ assert mock_client.call_args.kwargs["ttl"] == 60
+
+ ## call redis async increment and check if ttl correctly set
+ await router.cache.redis_cache.async_increment(**increment_cache_kwargs)
+
+ _redis_client = router.cache.redis_cache.init_async_client()
+
+ async with _redis_client as redis_client:
+ current_ttl = await redis_client.ttl(increment_cache_kwargs["key"])
+
+ assert current_ttl >= 0
+
+ print(f"current_ttl: {current_ttl}")
+
+
+def test_router_caching_ttl_sync():
+ """
+    Confirm caching TTLs work as expected.
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/5609
+ """
+ messages = [
+ {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+ ]
+ model = "azure-model"
+ model_list = [
+ {
+ "model_name": "azure-model",
+ "litellm_params": {
+ "model": "azure/gpt-turbo",
+ "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+ "api_base": "https://openai-france-1234.openai.azure.com",
+ "tpm": 1440,
+ "mock_response": "Hello world",
+ },
+ "model_info": {"id": 1},
+ }
+ ]
+ router = Router(
+ model_list=model_list,
+ routing_strategy="usage-based-routing-v2",
+ set_verbose=False,
+ redis_host=os.getenv("REDIS_HOST"),
+ redis_password=os.getenv("REDIS_PASSWORD"),
+ redis_port=os.getenv("REDIS_PORT"),
+ )
+
+ assert router.cache.redis_cache is not None
+
+ increment_cache_kwargs = {}
+ with patch.object(
+ router.cache.redis_cache,
+ "increment_cache",
+ new=MagicMock(),
+ ) as mock_client:
+ router.completion(model=model, messages=messages)
+
+ print(mock_client.call_args_list)
+ mock_client.assert_called()
+ print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+ print(f"mock_client.call_args.args: {mock_client.call_args.args}")
+
+ increment_cache_kwargs = {
+ "key": mock_client.call_args.args[0],
+ "value": mock_client.call_args.args[1],
+ "ttl": mock_client.call_args.kwargs["ttl"],
+ }
+
+ assert mock_client.call_args.kwargs["ttl"] == 60
+
+ ## call redis async increment and check if ttl correctly set
+ router.cache.redis_cache.increment_cache(**increment_cache_kwargs)
+
+ _redis_client = router.cache.redis_cache.redis_client
+
+ current_ttl = _redis_client.ttl(increment_cache_kwargs["key"])
+
+ assert current_ttl >= 0
+
+ print(f"current_ttl: {current_ttl}")
diff --git a/litellm/tests/test_traceloop.py b/litellm/tests/test_traceloop.py
index bcc120323..74d58228e 100644
--- a/litellm/tests/test_traceloop.py
+++ b/litellm/tests/test_traceloop.py
@@ -4,7 +4,6 @@ import time
import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-from traceloop.sdk import Traceloop
import litellm
@@ -13,6 +12,8 @@ sys.path.insert(0, os.path.abspath("../.."))
@pytest.fixture()
def exporter():
+ from traceloop.sdk import Traceloop
+
exporter = InMemorySpanExporter()
Traceloop.init(
app_name="test_litellm",