forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (11/09/2024) (#5634)
* fix(caching.py): set ttl for async_increment cache. Fixes issue where ttl for redis client was not being set on increment_cache. Fixes https://github.com/BerriAI/litellm/issues/5609
* fix(caching.py): fix increment cache w/ ttl for sync increment cache on redis. Fixes https://github.com/BerriAI/litellm/issues/5609
* fix(router.py): support adding retry policy + allowed fails policy via config.yaml
* fix(router.py): don't cooldown single deployments. No point, as there's no other deployment to loadbalance with.
* fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens. Closes https://github.com/BerriAI/litellm/issues/5605
* docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs
* fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set. Fixes issue where key logging would not be set if team metadata was not none
* fix(secret_managers/main.py): load environment variables correctly. Fixes issue where os.environ/ was not being loaded correctly
* test(test_router.py): fix test
* feat(spend_tracking_utils.py): support logging additional usage params - e.g. prompt caching values for deepseek
* test: fix tests
* test: fix test
* test: fix test
* test: fix test
* test: fix test
This commit is contained in:
parent 70100d716b
commit 98c34a7e27

25 changed files with 745 additions and 114 deletions
@@ -1,12 +1,12 @@
repos:
- repo: local
  hooks:
  - id: mypy
    name: mypy
    entry: python3 -m mypy --ignore-missing-imports
    language: system
    types: [python]
    files: ^litellm/
  # - id: mypy
  #   name: mypy
  #   entry: python3 -m mypy --ignore-missing-imports
  #   language: system
  #   types: [python]
  #   files: ^litellm/
  - id: isort
    name: isort
    entry: isort
@@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
}'
```


## Advanced - Upsert Users + Allowed Email Domains

Give users who belong to a specific email domain automatic access to the proxy.

```yaml
general_settings:
  master_key: sk-1234
  enable_jwt_auth: True
  litellm_jwtauth:
    user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload
    user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
    user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
```
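To see what the proxy inspects for this feature, here is a minimal sketch (not part of the docs diff) of a JWT payload that would pass the check above. The signing setup is an illustrative assumption; a real deployment validates signatures against the proxy's configured public keys, not a shared secret.

```python
# Hedged sketch: the JWT claims this config inspects. Signing/JWKS setup is
# omitted; 'email' matches user_email_jwt_field, and its domain must equal
# user_allowed_email_domain for the request to be allowed (and the user upserted).
import datetime

import jwt  # pyjwt, used here only to produce a structurally valid token

payload = {
    "sub": "user123",
    "email": "alice@my-co.com",  # allowed: domain == "my-co.com"
    "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=10),
}
token = jwt.encode(payload, "illustrative-secret", algorithm="HS256")
print(token)
```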
@@ -1038,6 +1038,12 @@ print(f"response: {response}")
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment

[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436)


<Tabs>
<TabItem value="sdk" label="SDK">

Example:

```python

@@ -1101,6 +1107,24 @@ response = await router.acompletion(
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

```yaml
router_settings:
  retry_policy: {
    "BadRequestErrorRetries": 3,
    "ContentPolicyViolationErrorRetries": 4
  }
  allowed_fails_policy: {
    "ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
    "RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment
  }
```

</TabItem>
</Tabs>
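The SDK tab above is truncated in this diff view. A minimal sketch of the equivalent SDK configuration, based on the `Router` signature changed in this commit (`retry_policy` and `allowed_fails_policy` accept either typed objects or plain dicts; the model list is illustrative):

```python
# Hedged sketch mirroring the proxy yaml above via the SDK. Per this commit,
# dicts are coerced to RetryPolicy / AllowedFailsPolicy inside Router.__init__.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    retry_policy={
        "BadRequestErrorRetries": 3,
        "ContentPolicyViolationErrorRetries": 4,
    },
    allowed_fails_policy={
        "ContentPolicyViolationErrorAllowedFails": 1000,
        "RateLimitErrorAllowedFails": 100,
    },
)
```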

### Fallbacks
@@ -304,40 +304,25 @@ class RedisCache(BaseCache):
                f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
            )

    def increment_cache(self, key, value: int, **kwargs) -> int:
    def increment_cache(
        self, key, value: int, ttl: Optional[float] = None, **kwargs
    ) -> int:
        _redis_client = self.redis_client
        start_time = time.time()
        try:
            result = _redis_client.incr(name=key, amount=value)
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.service_success_hook(
                    service=ServiceTypes.REDIS,
                    duration=_duration,
                    call_type="increment_cache",
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
                )
            )

            if ttl is not None:
                # check if key already has ttl, if not -> set ttl
                current_ttl = _redis_client.ttl(key)
                if current_ttl == -1:
                    # Key has no expiration
                    _redis_client.expire(key, ttl)
            return result
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()
            _duration = end_time - start_time
            asyncio.create_task(
                self.service_logger_obj.async_service_failure_hook(
                    service=ServiceTypes.REDIS,
                    duration=_duration,
                    error=e,
                    call_type="increment_cache",
                    start_time=start_time,
                    end_time=end_time,
                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
                )
            )
            verbose_logger.error(
                "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
                str(e),

@@ -606,12 +591,22 @@ class RedisCache(BaseCache):
        if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
            await self.flush_cache_buffer()  # logging done in here

    async def async_increment(self, key, value: float, **kwargs) -> float:
    async def async_increment(
        self, key, value: float, ttl: Optional[float] = None, **kwargs
    ) -> float:
        _redis_client = self.init_async_client()
        start_time = time.time()
        try:
            async with _redis_client as redis_client:
                result = await redis_client.incrbyfloat(name=key, amount=value)

                if ttl is not None:
                    # check if key already has ttl, if not -> set ttl
                    current_ttl = await redis_client.ttl(key)
                    if current_ttl == -1:
                        # Key has no expiration
                        await redis_client.expire(key, ttl)

                ## LOGGING ##
                end_time = time.time()
                _duration = end_time - start_time
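The TTL fix in both methods follows a standard Redis pattern: INCR creates the key without an expiry, so an expiry is set only when `TTL` reports -1 (key exists, no expiration). A standalone sketch of the same pattern with plain redis-py (connection details are illustrative):

```python
# Hedged sketch of the increment-with-ttl pattern used above, outside LiteLLM.
import redis

r = redis.Redis(host="localhost", port=6379)  # illustrative connection

def increment_with_ttl(key: str, amount: int, ttl: int) -> int:
    value = r.incr(key, amount)
    # INCR on a fresh key leaves it persistent; TTL == -1 means "no expiry set".
    if r.ttl(key) == -1:
        r.expire(key, ttl)
    return value
```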
@@ -1609,15 +1609,24 @@ class Logging:
        """
        from litellm.types.router import RouterErrors

        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}

        ## BASE CASE ## check if rate limit error for model group size 1
        is_base_case = False
        if metadata.get("model_group_size") is not None:
            model_group_size = metadata.get("model_group_size")
            if isinstance(model_group_size, int) and model_group_size == 1:
                is_base_case = True
        ## check if special error ##
        if RouterErrors.no_deployments_available.value not in str(exception):
        if (
            RouterErrors.no_deployments_available.value not in str(exception)
            and is_base_case is False
        ):
            return

        ## get original model group ##

        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}

        model_group = metadata.get("model_group") or None
        for callback in litellm._async_failure_callback:
            if isinstance(callback, CustomLogger):  # custom logger class
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
    - team_allowed_routes: list of allowed routes for proxy team roles.
    - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
    - user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
    - user_allowed_email_domain: If specified, only emails from the specified domain will be allowed to access the proxy.
    - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
    - public_key_ttl: Default - 600s. TTL for caching public JWT keys.

@@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
    )
    org_id_jwt_field: Optional[str] = None
    user_id_jwt_field: Optional[str] = None
    user_email_jwt_field: Optional[str] = None
    user_allowed_email_domain: Optional[str] = None
    user_id_upsert: bool = Field(
        default=False, description="If user doesn't exist, upsert them into the db."
    )
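Taken together, the new fields can be configured in code like this (a hedged sketch mirroring the `LiteLLM_JWTAuth` usage in this commit's tests; the domain value is illustrative):

```python
# Hedged sketch: configuring the new JWT auth fields. Mirrors the test added
# later in this commit; "my-co.com" is an illustrative domain.
from litellm.proxy._types import LiteLLM_JWTAuth

litellm_jwtauth = LiteLLM_JWTAuth(
    user_email_jwt_field="email",           # read the 'email' claim
    user_allowed_email_domain="my-co.com",  # only user@my-co.com may call the proxy
    user_id_upsert=True,                    # create the user row on first valid request
)
```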
@@ -1690,6 +1694,9 @@ class SpendLogsMetadata(TypedDict):
    Specific metadata k,v pairs logged to spendlogs for easier cost tracking
    """

    additional_usage_values: Optional[
        dict
    ]  # covers provider-specific usage information - e.g. prompt caching
    user_api_key: Optional[str]
    user_api_key_alias: Optional[str]
    user_api_key_team_id: Optional[str]
@@ -78,6 +78,19 @@ class JWTHandler:
            return False
        return True

    def is_enforced_email_domain(self) -> bool:
        """
        Returns:
        - True: if 'user_allowed_email_domain' is set
        - False: if 'user_allowed_email_domain' is None
        """

        if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
            self.litellm_jwtauth.user_allowed_email_domain, str
        ):
            return True
        return False

    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        try:
            if self.litellm_jwtauth.team_id_jwt_field is not None:

@@ -90,12 +103,14 @@ class JWTHandler:
            team_id = default_value
        return team_id

    def is_upsert_user_id(self) -> bool:
    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
        """
        Returns:
        - True: if 'user_id_upsert' is set
        - True: if 'user_id_upsert' is set AND valid_user_email is not False
        - False: if not
        """
        if valid_user_email is False:
            return False
        return self.litellm_jwtauth.user_id_upsert
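The `valid_user_email` gate amounts to a small truth table: an explicitly failed email check blocks the upsert, while `None` (check not enforced) and `True` fall through to the configured flag. A hedged, self-contained distillation:

```python
# Hedged distillation of the upsert gating added above, with a standalone
# function standing in for the JWTHandler method.
from typing import Optional

def is_upsert_user_id(user_id_upsert: bool, valid_user_email: Optional[bool] = None) -> bool:
    if valid_user_email is False:
        return False  # failed domain check: never upsert
    return user_id_upsert

assert is_upsert_user_id(True) is True                          # enforcement off
assert is_upsert_user_id(True, valid_user_email=False) is False  # bad email blocks upsert
assert is_upsert_user_id(False, valid_user_email=True) is False  # flag still respected
```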

    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:

@@ -103,11 +118,23 @@ class JWTHandler:
            if self.litellm_jwtauth.user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
            else:
                user_id = None
                user_id = default_value
        except KeyError:
            user_id = default_value
        return user_id

    def get_user_email(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        try:
            if self.litellm_jwtauth.user_email_jwt_field is not None:
                user_email = token[self.litellm_jwtauth.user_email_jwt_field]
            else:
                user_email = None
        except KeyError:
            user_email = default_value
        return user_email

    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        try:
            if self.litellm_jwtauth.org_id_jwt_field is not None:

@@ -183,6 +210,16 @@ class JWTHandler:

        return public_key

    def is_allowed_domain(self, user_email: str) -> bool:
        if self.litellm_jwtauth.user_allowed_email_domain is None:
            return True

        email_domain = user_email.split("@")[-1]  # Extract domain from email
        if email_domain == self.litellm_jwtauth.user_allowed_email_domain:
            return True
        else:
            return False
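Note the exact-match semantics: the check compares everything after the last `@` against the configured domain, so subdomains (and lookalike suffixes) do not match. A hedged sketch of the same comparison in isolation:

```python
# Hedged sketch of is_allowed_domain's comparison, extracted for illustration.
def is_allowed_domain(user_email: str, allowed_domain: str) -> bool:
    return user_email.split("@")[-1] == allowed_domain

assert is_allowed_domain("alice@my-co.com", "my-co.com") is True
assert is_allowed_domain("bob@mail.my-co.com", "my-co.com") is False      # subdomain rejected
assert is_allowed_domain("eve@my-co.com.evil.io", "my-co.com") is False   # suffix trick rejected
```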

    async def auth_jwt(self, token: str) -> dict:
        # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
        # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
@@ -250,6 +250,7 @@ async def user_api_key_auth(
                    raise Exception(
                        f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
                    )

                # get team id
                team_id = jwt_handler.get_team_id(
                    token=jwt_valid_token, default_value=None

@@ -296,10 +297,30 @@ async def user_api_key_auth(
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )
            # [OPTIONAL] allowed user email domains
            valid_user_email: Optional[bool] = None
            user_email: Optional[str] = None
            if jwt_handler.is_enforced_email_domain():
                """
                if 'user_allowed_email_domain' is set,

                - checks if token contains 'email' field
                - checks if 'email' is from an allowed domain
                """
                user_email = jwt_handler.get_user_email(
                    token=jwt_valid_token, default_value=None
                )
                if user_email is None:
                    valid_user_email = False
                else:
                    valid_user_email = jwt_handler.is_allowed_domain(
                        user_email=user_email
                    )

            # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
            user_object = None
            user_id = jwt_handler.get_user_id(
                token=jwt_valid_token, default_value=None
                token=jwt_valid_token, default_value=user_email
            )
            if user_id is not None:
                # get the user object

@@ -307,11 +328,12 @@ async def user_api_key_auth(
                    user_id=user_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    user_id_upsert=jwt_handler.is_upsert_user_id(),
                    user_id_upsert=jwt_handler.is_upsert_user_id(
                        valid_user_email=valid_user_email
                    ),
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )

            # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
            end_user_object = None
            end_user_id = jwt_handler.get_end_user_id(

@@ -802,7 +824,7 @@ async def user_api_key_auth(
        # collect information for alerting #
        ####################################

        user_email: Optional[str] = None
        user_email = None
        # Check if the token has any user id information
        if user_obj is not None:
            user_email = user_obj.user_email
@@ -107,7 +107,16 @@ def _get_dynamic_logging_metadata(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[TeamCallbackMetadata]:
    callback_settings_obj: Optional[TeamCallbackMetadata] = None
    if user_api_key_dict.team_metadata is not None:
    if (
        user_api_key_dict.metadata is not None
        and "logging" in user_api_key_dict.metadata
    ):
        for item in user_api_key_dict.metadata["logging"]:
            callback_settings_obj = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=callback_settings_obj,
            )
    elif user_api_key_dict.team_metadata is not None:
        team_metadata = user_api_key_dict.team_metadata
        if "callback_settings" in team_metadata:
            callback_settings = team_metadata.get("callback_settings", None) or {}

@@ -124,15 +133,7 @@ def _get_dynamic_logging_metadata(
                }
            }
            """
    elif (
        user_api_key_dict.metadata is not None
        and "logging" in user_api_key_dict.metadata
    ):
        for item in user_api_key_dict.metadata["logging"]:
            callback_settings_obj = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=callback_settings_obj,
            )

    return callback_settings_obj
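The reordering is the fix: key-level `metadata["logging"]` is now checked before team metadata, so a key's logging settings win even when the key belongs to a team (previously a non-None `team_metadata` short-circuited the key-level branch). A hedged distillation of the new precedence:

```python
# Hedged distillation of the precedence fix: key-level logging settings are
# consulted first; team callback_settings apply only when the key has none.
def pick_logging_source(key_metadata: dict, team_metadata: dict) -> str:
    if key_metadata and "logging" in key_metadata:
        return "key"
    if team_metadata and "callback_settings" in team_metadata:
        return "team"
    return "none"

assert pick_logging_source({"logging": [{}]}, {"callback_settings": {}}) == "key"
assert pick_logging_source({}, {"callback_settings": {}}) == "team"
```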
@@ -84,6 +84,7 @@ def get_logging_payload(
        user_api_key_team_alias=None,
        spend_logs_metadata=None,
        requester_ip_address=None,
        additional_usage_values=None,
    )
    if isinstance(metadata, dict):
        verbose_proxy_logger.debug(

@@ -100,6 +101,13 @@ def get_logging_payload(
            }
        )

    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values
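For example, with a DeepSeek-style usage payload (the prompt-cache field names below are an assumption based on the commit message's prompt-caching mention), the filter keeps only the non-standard keys:

```python
# Hedged example of the filtering above. The prompt-cache field names are an
# assumption for illustration; only the three standard fields are excluded.
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
usage = {
    "prompt_tokens": 12,
    "completion_tokens": 40,
    "total_tokens": 52,
    "prompt_cache_hit_tokens": 8,   # assumed provider-specific field
    "prompt_cache_miss_tokens": 4,  # assumed provider-specific field
}
additional_usage_values = {k: v for k, v in usage.items() if k not in special_usage_fields}
assert additional_usage_values == {"prompt_cache_hit_tokens": 8, "prompt_cache_miss_tokens": 4}
```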

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
@@ -161,10 +161,10 @@ class Router:
        enable_tag_filtering: bool = False,
        retry_after: int = 0,  # min time to wait before retrying a failed request
        retry_policy: Optional[
            RetryPolicy
            Union[RetryPolicy, dict]
        ] = None,  # set custom retries for different exceptions
        model_group_retry_policy: Optional[
            Dict[str, RetryPolicy]
        model_group_retry_policy: Dict[
            str, RetryPolicy
        ] = {},  # set custom retry policies based on model group
        allowed_fails: Optional[
            int

@@ -263,7 +263,7 @@ class Router:
        self.debug_level = debug_level
        self.enable_pre_call_checks = enable_pre_call_checks
        self.enable_tag_filtering = enable_tag_filtering
        if self.set_verbose == True:
        if self.set_verbose is True:
            if debug_level == "INFO":
                verbose_router_logger.setLevel(logging.INFO)
            elif debug_level == "DEBUG":

@@ -454,11 +454,35 @@ class Router:
            )

        self.routing_strategy_args = routing_strategy_args
        self.retry_policy: Optional[RetryPolicy] = retry_policy
        self.retry_policy: Optional[RetryPolicy] = None
        if retry_policy is not None:
            if isinstance(retry_policy, dict):
                self.retry_policy = RetryPolicy(**retry_policy)
            elif isinstance(retry_policy, RetryPolicy):
                self.retry_policy = retry_policy
            verbose_router_logger.info(
                "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
                    self.retry_policy.model_dump(exclude_none=True)
                )
            )

        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
            model_group_retry_policy
        )
        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy

        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
        if allowed_fails_policy is not None:
            if isinstance(allowed_fails_policy, dict):
                self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
            elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
                self.allowed_fails_policy = allowed_fails_policy

            verbose_router_logger.info(
                "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
                    self.allowed_fails_policy.model_dump(exclude_none=True)
                )
            )

        self.alerting_config: Optional[AlertingConfig] = alerting_config
        if self.alerting_config is not None:
            self._initialize_alerting()

@@ -3003,6 +3027,13 @@ class Router:
        model_group = kwargs.get("model")
        num_retries = kwargs.pop("num_retries")

        ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
        _metadata: dict = kwargs.get("metadata") or {}
        if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
            model_list = self.get_model_list(model_name=_metadata["model_group"])
            if model_list is not None:
                _metadata.update({"model_group_size": len(model_list)})

        verbose_router_logger.debug(
            f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
        )
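`model_group_size` feeds the base case added to `Logging.deployment_callback_on_failure` earlier in this diff: with a single deployment there is nothing to cool down, so a rate-limit failure should still reach the model-group rate-limit callback. A hedged distillation of that decision (the error string below is a stand-in for `RouterErrors.no_deployments_available.value`, not the real enum value):

```python
# Hedged distillation: when should the model-group rate-limit callback fire?
NO_DEPLOYMENTS_MSG = "No deployments available"  # stand-in for the real enum value

def should_fire_rate_limit_callback(exception_str: str, model_group_size: int) -> bool:
    is_base_case = model_group_size == 1  # lone deployment: no failover possible
    return NO_DEPLOYMENTS_MSG in exception_str or is_base_case

assert should_fire_rate_limit_callback("RateLimitError", model_group_size=1) is True
assert should_fire_rate_limit_callback("RateLimitError", model_group_size=2) is False
```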
@@ -3165,6 +3196,7 @@ class Router:
        If it fails after num_retries, fall back to another model group
        """
        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)

        model_group = kwargs.get("model")
        fallbacks = kwargs.get("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.get(

@@ -3173,6 +3205,7 @@ class Router:
        content_policy_fallbacks = kwargs.get(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )

        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
                raise Exception(

@@ -3324,6 +3357,9 @@ class Router:
            f"Inside function with retries: args - {args}; kwargs - {kwargs}"
        )
        original_function = kwargs.pop("original_function")
        mock_testing_rate_limit_error = kwargs.pop(
            "mock_testing_rate_limit_error", None
        )
        num_retries = kwargs.pop("num_retries")
        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.pop(

@@ -3332,9 +3368,22 @@ class Router:
        content_policy_fallbacks = kwargs.pop(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )
        model_group = kwargs.get("model")

        try:
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            if (
                mock_testing_rate_limit_error is not None
                and mock_testing_rate_limit_error is True
            ):
                verbose_router_logger.info(
                    "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
                )
                raise litellm.RateLimitError(
                    model=model_group,
                    llm_provider="",
                    message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
                )
            response = original_function(*args, **kwargs)
            return response
        except Exception as e:

@@ -3571,17 +3620,26 @@ class Router:
        )  # don't change existing ttl

    def _is_cooldown_required(
        self, exception_status: Union[str, int], exception_str: Optional[str] = None
    ):
        self,
        model_id: str,
        exception_status: Union[str, int],
        exception_str: Optional[str] = None,
    ) -> bool:
        """
        A function to determine if a cooldown is required based on the exception status.

        Parameters:
            model_id (str): The id of the model in the model list
            exception_status (Union[str, int]): The status of the exception.

        Returns:
            bool: True if a cooldown is required, False otherwise.
        """
        ## BASE CASE - single deployment
        model_group = self.get_model_group(id=model_id)
        if model_group is not None and len(model_group) == 1:
            return False

        try:
            ignored_strings = ["APIConnectionError"]
            if (

@@ -3677,7 +3735,9 @@ class Router:

        if (
            self._is_cooldown_required(
                exception_status=exception_status, exception_str=str(original_exception)
                model_id=deployment,
                exception_status=exception_status,
                exception_str=str(original_exception),
            )
            is False
        ):

@@ -3690,7 +3750,9 @@ class Router:
                exception=original_exception,
            )

            allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
            allowed_fails = (
                _allowed_fails if _allowed_fails is not None else self.allowed_fails
            )

            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")

@@ -4298,6 +4360,18 @@ class Router:
                return model
        return None

    def get_model_group(self, id: str) -> Optional[List]:
        """
        Return list of all models in the same model group as that model id
        """

        model_info = self.get_model_info(id=id)
        if model_info is None:
            return None

        model_name = model_info["model_name"]
        return self.get_model_list(model_name=model_name)

    def _set_model_group_info(
        self, model_group: str, user_facing_model_group_name: str
    ) -> Optional[ModelGroupInfo]:

@@ -5528,18 +5602,19 @@ class Router:
        ContentPolicyViolationErrorRetries: Optional[int] = None
        """
        # if we can find the exception in the retry policy -> return the number of retries
        retry_policy = self.retry_policy
        retry_policy: Optional[RetryPolicy] = self.retry_policy
        if (
            self.model_group_retry_policy is not None
            and model_group is not None
            and model_group in self.model_group_retry_policy
        ):
            retry_policy = self.model_group_retry_policy.get(model_group, None)
            retry_policy = self.model_group_retry_policy.get(model_group, None)  # type: ignore

        if retry_policy is None:
            return None
        if isinstance(retry_policy, dict):
            retry_policy = RetryPolicy(**retry_policy)

        if (
            isinstance(exception, litellm.BadRequestError)
            and retry_policy.BadRequestErrorRetries is not None
@@ -29,6 +29,27 @@ def _is_base64(s):
    return False


def str_to_bool(value: str) -> Optional[bool]:
    """
    Converts a string to a boolean if it's a recognized boolean string.
    Returns None if the string is not a recognized boolean value.

    :param value: The string to be checked.
    :return: True or False if the string is a recognized boolean, otherwise None.
    """
    true_values = {"true"}
    false_values = {"false"}

    value_lower = value.strip().lower()

    if value_lower in true_values:
        return True
    elif value_lower in false_values:
        return False
    else:
        return None
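The replacement matters because `ast.literal_eval` raises on most secret strings (for example `sk-1234`, or `true` itself, which is not a Python literal), which previously pushed valid values into the exception path. A few hedged examples of the new helper's behavior (assuming it is importable from `litellm.secret_managers.main`, the file this commit touches):

```python
# Hedged examples: str_to_bool only recognizes "true"/"false" (any case, padded)
# and returns None for everything else, so raw secrets pass through unchanged.
from litellm.secret_managers.main import str_to_bool  # assumed import path

assert str_to_bool("TRUE") is True
assert str_to_bool("  false ") is False
assert str_to_bool("sk-1234") is None  # not a boolean: caller returns the string as-is
```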

def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,

@@ -257,17 +278,12 @@ def get_secret(
            return secret
        else:
            secret = os.environ.get(secret_name)
            try:
                secret_value_as_bool = (
                    ast.literal_eval(secret) if secret is not None else None
                )
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception:
                if default_value is not None:
                    return default_value
                secret_value_as_bool = str_to_bool(secret) if secret is not None else None
                if secret_value_as_bool is not None and isinstance(
                    secret_value_as_bool, bool
                ):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception as e:
                if default_value is not None:
@@ -54,6 +54,7 @@ VERTEX_MODELS_TO_NOT_TEST = [
    "gemini-flash-experimental",
    "gemini-1.5-flash-exp-0827",
    "gemini-pro-flash",
    "gemini-1.5-flash-exp-0827",
]
@@ -38,6 +38,8 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings

litellm.num_retries = 0


class CompletionCustomHandler(
    CustomLogger

@@ -401,7 +403,7 @@ async def test_async_chat_azure():
            "rpm": 1800,
        },
    ]
    router = Router(model_list=model_list)  # type: ignore
    router = Router(model_list=model_list, num_retries=0)  # type: ignore
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],

@@ -413,7 +415,7 @@ async def test_async_chat_azure():
    )  # pre, post, success
    # streaming
    litellm.callbacks = [customHandler_streaming_azure_router]
    router2 = Router(model_list=model_list)  # type: ignore
    router2 = Router(model_list=model_list, num_retries=0)  # type: ignore
    response = await router2.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],

@@ -443,7 +445,7 @@ async def test_async_chat_azure():
        },
    ]
    litellm.callbacks = [customHandler_failure]
    router3 = Router(model_list=model_list)  # type: ignore
    router3 = Router(model_list=model_list, num_retries=0)  # type: ignore
    try:
        response = await router3.acompletion(
            model="gpt-3.5-turbo",

@@ -505,7 +507,7 @@ async def test_async_embedding_azure():
        },
    ]
    litellm.callbacks = [customHandler_failure]
    router3 = Router(model_list=model_list)  # type: ignore
    router3 = Router(model_list=model_list, num_retries=0)  # type: ignore
    try:
        response = await router3.aembedding(
            model="azure-embedding-model", input=["hello from litellm!"]

@@ -678,22 +680,21 @@ async def test_rate_limit_error_callback():
        pass

    with patch.object(
        customHandler, "log_model_group_rate_limit_error", new=MagicMock()
        customHandler, "log_model_group_rate_limit_error", new=AsyncMock()
    ) as mock_client:

        print(
            f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
        )

        for _ in range(3):
            try:
                _ = await router.acompletion(
                    model="my-test-gpt",
                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
                    litellm_logging_obj=litellm_logging_obj,
                )
            except (litellm.RateLimitError, ValueError):
                pass
        try:
            _ = await router.acompletion(
                model="my-test-gpt",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                litellm_logging_obj=litellm_logging_obj,
            )
        except (litellm.RateLimitError, ValueError):
            pass

        await asyncio.sleep(3)
        mock_client.assert_called_once()
@@ -23,8 +23,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request

import litellm
from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion

@@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience):
        raise e


from unittest.mock import AsyncMock

import pytest


@@ -844,3 +843,148 @@ async def test_team_cache_update_called():

    await asyncio.sleep(3)
    mock_call_cache.assert_awaited_once()


@pytest.fixture
def public_jwt_key():
    import json

    import jwt
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate a private / public key pair using RSA algorithm
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # Get private key in PEM format
    private_key = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Get public key in PEM format
    public_key = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    public_key_obj = serialization.load_pem_public_key(
        public_key, backend=default_backend()
    )

    # Convert RSA public key object to JWK (JSON Web Key)
    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))

    return {"private_key": private_key, "public_jwk": public_jwk}


def mock_user_object(*args, **kwargs):
    print("Args: {}".format(args))
    print("kwargs: {}".format(kwargs))
    assert kwargs["user_id_upsert"] is True


@pytest.mark.parametrize(
    "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
    """
    Allow anyone with an `@xyz.com` email to make a request to the proxy.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5605
    """
    import jwt
    from starlette.datastructures import URL

    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    public_jwk = public_jwt_key["public_jwk"]
    private_key = public_jwt_key["private_key"]

    # set cache
    cache = DualCache()

    await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])

    jwt_handler = JWTHandler()

    jwt_handler.user_api_key_cache = cache

    jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
        user_email_jwt_field="email",
        user_allowed_email_domain="berri.ai",
        user_id_upsert=True,
    )

    # VALID TOKEN
    ## GENERATE A TOKEN
    # Assuming the current time is in UTC
    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())

    team_id = f"team123_{uuid.uuid4()}"
    payload = {
        "sub": "user123",
        "exp": expiration_time,  # set the token to expire in 10 minutes
        "scope": "litellm_team",
        "client_id": team_id,
        "aud": "litellm-proxy",
        "email": user_email,
    }

    # Generate the JWT token
    # But before, you should convert bytes to string
    private_key_str = private_key.decode("utf-8")

    ## team token
    token = jwt.encode(payload, private_key_str, algorithm="RS256")

    ## VERIFY IT WORKS
    # Expect the call to succeed
    response = await jwt_handler.auth_jwt(token=token)
    assert response is not None  # Adjust this based on your actual response check

    ## RUN IT THROUGH USER API KEY AUTH
    bearer_token = "Bearer " + token

    request = Request(scope={"type": "http"})

    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    # use generated key to auth in
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    setattr(litellm.proxy.proxy_server, "prisma_client", {})

    # AsyncMock(
    #     return_value=LiteLLM_UserTable(
    #         spend=0, user_id=user_email, max_budget=None, user_email=user_email
    #     )
    # ),
    with patch.object(
        litellm.proxy.auth.user_api_key_auth,
        "get_user_object",
        side_effect=mock_user_object,
    ) as mock_client:
        if should_work:
            # Expect the call to succeed
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            assert result is not None  # Adjust this based on your actual response check
        else:
            # Expect the call to fail
            with pytest.raises(
                Exception
            ):  # Replace with the actual exception raised on failure
                resp = await user_api_key_auth(request=request, api_key=bearer_token)
                print(resp)
@@ -1,16 +1,17 @@
import sys
import os
import io
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))

from litellm import completion
import litellm
from litellm import completion

litellm.failure_callback = ["lunary"]
litellm.success_callback = ["lunary"]
litellm.set_verbose = True


def test_lunary_logging():
    try:
        response = completion(

@@ -24,6 +25,7 @@ def test_lunary_logging():
    except Exception as e:
        print(e)


test_lunary_logging()


@@ -37,8 +39,6 @@ def test_lunary_template():
    except Exception as e:
        print(e)

test_lunary_template()


def test_lunary_logging_with_metadata():
    try:

@@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata():
            metadata={
                "run_name": "litellmRUN",
                "project_name": "litellm-completion",
                "tags": ["tag1", "tag2"]
                "tags": ["tag1", "tag2"],
            },
        )
        print(response)
    except Exception as e:
        print(e)

test_lunary_logging_with_metadata()

def test_lunary_with_tools():
    import litellm

    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
        }
    ]
    tools = [
        {
            "type": "function",

@@ -90,13 +94,11 @@ def test_lunary_with_tools():
        tools=tools,
        tool_choice="auto",  # auto is default, but we'll be explicit
    )


    response_message = response.choices[0].message
    print("\nLLM Response:\n", response.choices[0].message)


test_lunary_with_tools()

def test_lunary_logging_with_streaming_and_metadata():
    try:
        response = completion(

@@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata():
            continue
    except Exception as e:
        print(e)

test_lunary_logging_with_streaming_and_metadata()
@@ -11,8 +11,11 @@ import litellm
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import (
    _get_dynamic_logging_metadata,
    add_litellm_data_to_request,
)
from litellm.types.utils import SupportedCacheControls


@@ -204,3 +207,87 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request(
    # assert (
    #     new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"]
    # )


@pytest.mark.parametrize(
    "callback_vars",
    [
        {
            "langfuse_host": "https://us.cloud.langfuse.com",
            "langfuse_public_key": "pk-lf-9636b7a6-c066",
            "langfuse_secret_key": "sk-lf-7cc8b620",
        },
        {
            "langfuse_host": "os.environ/LANGFUSE_HOST_TEMP",
            "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP",
            "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP",
        },
    ],
)
def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
    os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
    os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
    os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
    user_api_key_dict = UserAPIKeyAuth(
        token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
        key_name="sk-...63Fg",
        key_alias=None,
        spend=0.000111,
        max_budget=None,
        expires=None,
        models=[],
        aliases={},
        config={},
        user_id=None,
        team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
        max_parallel_requests=None,
        metadata={
            "logging": [
                {
                    "callback_name": "langfuse",
                    "callback_type": "success",
                    "callback_vars": callback_vars,
                }
            ]
        },
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        budget_reset_at=None,
        allowed_cache_controls=[],
        permissions={},
        model_spend={},
        model_max_budget={},
        soft_budget_cooldown=False,
        litellm_budget_table=None,
        org_id=None,
        team_spend=0.000132,
        team_alias=None,
        team_tpm_limit=None,
        team_rpm_limit=None,
        team_max_budget=None,
        team_models=[],
        team_blocked=False,
        soft_budget=None,
        team_model_aliases=None,
        team_member_spend=None,
        team_member=None,
        team_metadata={},
        end_user_id=None,
        end_user_tpm_limit=None,
        end_user_rpm_limit=None,
        end_user_max_budget=None,
        last_refreshed_at=1726101560.967527,
        api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1",
        user_role=LitellmUserRoles.INTERNAL_USER,
        allowed_model_region=None,
        parent_otel_span=None,
        rpm_limit_per_model=None,
        tpm_limit_per_model=None,
    )
    callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)

    assert callbacks is not None

    for var in callbacks.callback_vars.values():
        assert "os.environ" not in var
@@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error():
    except litellm.APIConnectionError as e:
        assert (
            Router()._is_cooldown_required(
                exception_status=e.code, exception_str=str(e)
                model_id="", exception_status=e.code, exception_str=str(e)
            )
            is False
        )

@@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
            "litellm_params": {
                "model": "openai/text-embedding-ada-002",
            },
        }
        },
        {
            "model_name": "text-embedding-ada-002",
            "litellm_params": {
                "model": "openai/text-embedding-ada-002",
            },
        },
    ]
)
@@ -21,6 +21,7 @@ import openai
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict


@pytest.mark.asyncio

@@ -112,3 +113,40 @@ async def test_dynamic_cooldowns():

    assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
    assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0


@pytest.mark.parametrize("num_deployments", [1, 2])
def test_single_deployment_no_cooldowns(num_deployments):
    """
    Do not cooldown on single deployment.

    Cooldown on multiple deployments.
    """
    model_list = []
    for i in range(num_deployments):
        model = DeploymentTypedDict(
            model_name="gpt-3.5-turbo",
            litellm_params=LiteLLMParamsTypedDict(
                model="gpt-3.5-turbo",
            ),
        )
        model_list.append(model)

    router = Router(model_list=model_list, allowed_fails=0, num_retries=0)

    with patch.object(
        router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
    ) as mock_client:
        try:
            router.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                mock_response="litellm.RateLimitError",
            )
        except litellm.RateLimitError:
            pass

        if num_deployments == 1:
            mock_client.assert_not_called()
        else:
            mock_client.assert_called_once()
@@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type):
            "tpm": 240000,
            "rpm": 1800,
        },
        {
            "model_name": "azure/gpt-3.5-turbo",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "azure/chatgpt-functioncalling",
                "api_key": _api_key,
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
    ]

    router = Router(model_list=model_list, allowed_fails=3)
@@ -17,6 +17,8 @@ import os
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm

@@ -459,3 +461,138 @@ async def test_router_completion_streaming():
    - Unit test for sync 'pre_call_checks'
    - Unit test for async 'async_pre_call_checks'
    """


@pytest.mark.asyncio
async def test_router_caching_ttl():
    """
    Confirm caching TTLs work as expected.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5609
    """
    messages = [
        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
    ]
    model = "azure-model"
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "tpm": 1440,
                "mock_response": "Hello world",
            },
            "model_info": {"id": 1},
        }
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing-v2",
        set_verbose=False,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=os.getenv("REDIS_PORT"),
    )

    assert router.cache.redis_cache is not None

    increment_cache_kwargs = {}
    with patch.object(
        router.cache.redis_cache,
        "async_increment",
        new=AsyncMock(),
    ) as mock_client:
        await router.acompletion(model=model, messages=messages)

        mock_client.assert_called_once()
        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        print(f"mock_client.call_args.args: {mock_client.call_args.args}")

        increment_cache_kwargs = {
            "key": mock_client.call_args.args[0],
            "value": mock_client.call_args.args[1],
            "ttl": mock_client.call_args.kwargs["ttl"],
        }

        assert mock_client.call_args.kwargs["ttl"] == 60

    ## call redis async increment and check if ttl correctly set
    await router.cache.redis_cache.async_increment(**increment_cache_kwargs)

    _redis_client = router.cache.redis_cache.init_async_client()

    async with _redis_client as redis_client:
        current_ttl = await redis_client.ttl(increment_cache_kwargs["key"])

        assert current_ttl >= 0

        print(f"current_ttl: {current_ttl}")


def test_router_caching_ttl_sync():
    """
    Confirm caching TTLs work as expected.

    Relevant issue: https://github.com/BerriAI/litellm/issues/5609
    """
    messages = [
        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
    ]
    model = "azure-model"
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "tpm": 1440,
                "mock_response": "Hello world",
            },
            "model_info": {"id": 1},
        }
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing-v2",
        set_verbose=False,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=os.getenv("REDIS_PORT"),
    )

    assert router.cache.redis_cache is not None

    increment_cache_kwargs = {}
    with patch.object(
        router.cache.redis_cache,
        "increment_cache",
        new=MagicMock(),
    ) as mock_client:
        router.completion(model=model, messages=messages)

        print(mock_client.call_args_list)
        mock_client.assert_called()
        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        print(f"mock_client.call_args.args: {mock_client.call_args.args}")

        increment_cache_kwargs = {
            "key": mock_client.call_args.args[0],
            "value": mock_client.call_args.args[1],
            "ttl": mock_client.call_args.kwargs["ttl"],
        }

        assert mock_client.call_args.kwargs["ttl"] == 60

    ## call redis async increment and check if ttl correctly set
    router.cache.redis_cache.increment_cache(**increment_cache_kwargs)

    _redis_client = router.cache.redis_cache.redis_client

    current_ttl = _redis_client.ttl(increment_cache_kwargs["key"])

    assert current_ttl >= 0

    print(f"current_ttl: {current_ttl}")
@@ -4,7 +4,6 @@ import time

import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from traceloop.sdk import Traceloop

import litellm


@@ -13,6 +12,8 @@ sys.path.insert(0, os.path.abspath("../.."))

@pytest.fixture()
def exporter():
    from traceloop.sdk import Traceloop

    exporter = InMemorySpanExporter()
    Traceloop.init(
        app_name="test_litellm",