LiteLLM Minor Fixes and Improvements (11/09/2024) (#5634)

* fix(caching.py): set ttl for async_increment cache

Fixes an issue where the TTL was not being set on the Redis client in increment_cache.

Fixes https://github.com/BerriAI/litellm/issues/5609

* fix(caching.py): fix increment cache w/ ttl for sync increment cache on redis

Fixes https://github.com/BerriAI/litellm/issues/5609

* fix(router.py): support adding retry policy + allowed fails policy via config.yaml

* fix(router.py): don't cooldown single deployments

There's no point in cooling down, as there is no other deployment to load-balance to.

* fix(user_api_key_auth.py): support setting allowed email domains on jwt tokens

Closes https://github.com/BerriAI/litellm/issues/5605

* docs(token_auth.md): add user upsert + allowed email domain to jwt auth docs

* fix(litellm_pre_call_utils.py): fix dynamic key logging when team id is set

Fixes an issue where key-level logging settings would not be applied if team metadata was not None.

* fix(secret_managers/main.py): load environment variables correctly

Fixes an issue where os.environ/-prefixed values were not being loaded correctly.

* test(test_router.py): fix test

* feat(spend_tracking_utils.py): support logging additional usage params - e.g. prompt caching values for deepseek

* test: fix tests

* test: fix test

* test: fix test

* test: fix test

* test: fix test
Krish Dholakia 2024-09-11 22:36:06 -07:00 committed by GitHub
parent 70100d716b
commit 98c34a7e27
25 changed files with 745 additions and 114 deletions


@ -1,12 +1,12 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/
- id: isort
name: isort
entry: isort


@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \
}'
```
## Advanced - Upsert Users + Allowed Email Domains
Allow users who belong to a specific email domain automatic access to the proxy.
```yaml
general_settings:
master_key: sk-1234
enable_jwt_auth: True
litellm_jwtauth:
user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload
user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy
user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db
```
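For reference, the check this config enables is a straight string comparison on the part of the email after the `@`, mirroring the `is_allowed_domain` handler added later in this commit (the standalone function and sample emails below are an illustrative sketch, not the proxy's public API):

```python
from typing import Optional


def is_allowed_domain(user_email: str, allowed_email_domain: Optional[str]) -> bool:
    # No domain restriction configured -> allow everyone
    if allowed_email_domain is None:
        return True
    # Compare the part after '@' against the configured domain
    return user_email.split("@")[-1] == allowed_email_domain


assert is_allowed_domain("user@my-co.com", "my-co.com") is True
assert is_allowed_domain("user@other-co.com", "my-co.com") is False
```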


@ -1038,6 +1038,12 @@ print(f"response: {response}")
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment
[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436)
<Tabs>
<TabItem value="sdk" label="SDK">
Example:
```python
@ -1101,6 +1107,24 @@ response = await router.acompletion(
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
router_settings:
retry_policy: {
"BadRequestErrorRetries": 3,
"ContentPolicyViolationErrorRetries": 4
}
allowed_fails_policy: {
"ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
"RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment
}
```
</TabItem>
</Tabs>
### Fallbacks


@ -304,40 +304,25 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
)
def increment_cache(self, key, value: int, **kwargs) -> int:
def increment_cache(
self, key, value: int, ttl: Optional[float] = None, **kwargs
) -> int:
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.service_success_hook(
service=ServiceTypes.REDIS,
duration=_duration,
call_type="increment_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
if ttl is not None:
# check if key already has ttl, if not -> set ttl
current_ttl = _redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
_redis_client.expire(key, ttl)
return result
except Exception as e:
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
service=ServiceTypes.REDIS,
duration=_duration,
error=e,
call_type="increment_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
verbose_logger.error(
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
str(e),
@ -606,12 +591,22 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() # logging done in here
async def async_increment(self, key, value: float, **kwargs) -> float:
async def async_increment(
self, key, value: float, ttl: Optional[float] = None, **kwargs
) -> float:
_redis_client = self.init_async_client()
start_time = time.time()
try:
async with _redis_client as redis_client:
result = await redis_client.incrbyfloat(name=key, amount=value)
if ttl is not None:
# check if key already has ttl, if not -> set ttl
current_ttl = await redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
await redis_client.expire(key, ttl)
## LOGGING ##
end_time = time.time()
_duration = end_time - start_time
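The fix in both the sync and async paths is the same pattern: increment first, then set an expiry only if the key has none, so repeated increments within the window don't keep extending the TTL. A standalone sketch with a plain redis-py client (host/port are illustrative):

```python
import redis

client = redis.Redis(host="localhost", port=6379)  # illustrative connection


def increment_with_ttl(key: str, value: int, ttl: int) -> int:
    result = client.incr(name=key, amount=value)
    # ttl() returns -1 when the key exists but has no expiration set
    if client.ttl(key) == -1:
        client.expire(key, ttl)
    return result
```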


@ -1609,15 +1609,24 @@ class Logging:
"""
from litellm.types.router import RouterErrors
litellm_params: dict = self.model_call_details.get("litellm_params") or {}
metadata = litellm_params.get("metadata") or {}
## BASE CASE ## check if rate limit error for model group size 1
is_base_case = False
if metadata.get("model_group_size") is not None:
model_group_size = metadata.get("model_group_size")
if isinstance(model_group_size, int) and model_group_size == 1:
is_base_case = True
## check if special error ##
if RouterErrors.no_deployments_available.value not in str(exception):
if (
RouterErrors.no_deployments_available.value not in str(exception)
and is_base_case is False
):
return
## get original model group ##
litellm_params: dict = self.model_call_details.get("litellm_params") or {}
metadata = litellm_params.get("metadata") or {}
model_group = metadata.get("model_group") or None
for callback in litellm._async_failure_callback:
if isinstance(callback, CustomLogger): # custom logger class

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
- team_allowed_routes: list of allowed routes for proxy team roles.
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees.
- user_allowed_email_domain: If specified, only emails from the specified domain will be allowed to access the proxy.
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
)
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
user_email_jwt_field: Optional[str] = None
user_allowed_email_domain: Optional[str] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
@ -1690,6 +1694,9 @@ class SpendLogsMetadata(TypedDict):
Specific metadata k,v pairs logged to spendlogs for easier cost tracking
"""
additional_usage_values: Optional[
dict
] # covers provider-specific usage information - e.g. prompt caching
user_api_key: Optional[str]
user_api_key_alias: Optional[str]
user_api_key_team_id: Optional[str]


@ -78,6 +78,19 @@ class JWTHandler:
return False
return True
def is_enforced_email_domain(self) -> bool:
"""
Returns:
- True: if 'user_allowed_email_domain' is set
- False: if 'user_allowed_email_domain' is None
"""
if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
self.litellm_jwtauth.user_allowed_email_domain, str
):
return True
return False
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.team_id_jwt_field is not None:
@ -90,12 +103,14 @@ class JWTHandler:
team_id = default_value
return team_id
def is_upsert_user_id(self) -> bool:
def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
"""
Returns:
- True: if 'user_id_upsert' is set
- True: if 'user_id_upsert' is set AND valid_user_email is not False
- False: if not
"""
if valid_user_email is False:
return False
return self.litellm_jwtauth.user_id_upsert
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
@ -103,11 +118,23 @@ class JWTHandler:
if self.litellm_jwtauth.user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.user_id_jwt_field]
else:
user_id = None
user_id = default_value
except KeyError:
user_id = default_value
return user_id
def get_user_email(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try:
if self.litellm_jwtauth.user_email_jwt_field is not None:
user_email = token[self.litellm_jwtauth.user_email_jwt_field]
else:
user_email = None
except KeyError:
user_email = default_value
return user_email
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.org_id_jwt_field is not None:
@ -183,6 +210,16 @@ class JWTHandler:
return public_key
def is_allowed_domain(self, user_email: str) -> bool:
if self.litellm_jwtauth.user_allowed_email_domain is None:
return True
email_domain = user_email.split("@")[-1] # Extract domain from email
if email_domain == self.litellm_jwtauth.user_allowed_email_domain:
return True
else:
return False
async def auth_jwt(self, token: str) -> dict:
# Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
# "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret


@ -250,6 +250,7 @@ async def user_api_key_auth(
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# get team id
team_id = jwt_handler.get_team_id(
token=jwt_valid_token, default_value=None
@ -296,10 +297,30 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] allowed user email domains
valid_user_email: Optional[bool] = None
user_email: Optional[str] = None
if jwt_handler.is_enforced_email_domain():
"""
if 'allowed_email_subdomains' is set,
- checks if token contains 'email' field
- checks if 'email' is from an allowed domain
"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
if user_email is None:
valid_user_email = False
else:
valid_user_email = jwt_handler.is_allowed_domain(
user_email=user_email
)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(
token=jwt_valid_token, default_value=None
token=jwt_valid_token, default_value=user_email
)
if user_id is not None:
# get the user object
@ -307,11 +328,12 @@ async def user_api_key_auth(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(),
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
end_user_id = jwt_handler.get_end_user_id(
@ -802,7 +824,7 @@ async def user_api_key_auth(
# collect information for alerting #
####################################
user_email: Optional[str] = None
user_email = None
# Check if the token has any user id information
if user_obj is not None:
user_email = user_obj.user_email


@ -107,7 +107,16 @@ def _get_dynamic_logging_metadata(
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[TeamCallbackMetadata]:
callback_settings_obj: Optional[TeamCallbackMetadata] = None
if user_api_key_dict.team_metadata is not None:
if (
user_api_key_dict.metadata is not None
and "logging" in user_api_key_dict.metadata
):
for item in user_api_key_dict.metadata["logging"]:
callback_settings_obj = convert_key_logging_metadata_to_callback(
data=AddTeamCallback(**item),
team_callback_settings_obj=callback_settings_obj,
)
elif user_api_key_dict.team_metadata is not None:
team_metadata = user_api_key_dict.team_metadata
if "callback_settings" in team_metadata:
callback_settings = team_metadata.get("callback_settings", None) or {}
@ -124,15 +133,7 @@ def _get_dynamic_logging_metadata(
}
}
"""
elif (
user_api_key_dict.metadata is not None
and "logging" in user_api_key_dict.metadata
):
for item in user_api_key_dict.metadata["logging"]:
callback_settings_obj = convert_key_logging_metadata_to_callback(
data=AddTeamCallback(**item),
team_callback_settings_obj=callback_settings_obj,
)
return callback_settings_obj
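In effect, key-level `metadata["logging"]` entries are now checked before team metadata, instead of being skipped whenever team metadata was present. A minimal sketch of the new behavior (assumes `UserAPIKeyAuth` can be constructed with just these fields; the callback values are placeholders):

```python
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import _get_dynamic_logging_metadata

user_api_key_dict = UserAPIKeyAuth(
    team_id="my-team",
    team_metadata={},  # team defines no callback_settings
    metadata={
        "logging": [
            {
                "callback_name": "langfuse",
                "callback_type": "success",
                "callback_vars": {
                    "langfuse_host": "https://cloud.langfuse.com",
                    "langfuse_public_key": "pk-lf-...",
                    "langfuse_secret_key": "sk-lf-...",
                },
            }
        ]
    },
)

# Previously this returned None, because the non-None team_metadata branch
# short-circuited the key-level check; with this fix the key-level logging
# settings are picked up.
callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
assert callbacks is not None
```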


@ -84,6 +84,7 @@ def get_logging_payload(
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
)
if isinstance(metadata, dict):
verbose_proxy_logger.debug(
@ -100,6 +101,13 @@ def get_logging_payload(
}
)
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}
for k, v in usage.items():
if k not in special_usage_fields:
additional_usage_values.update({k: v})
clean_metadata["additional_usage_values"] = additional_usage_values
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
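Worked from the hunk above: any usage keys beyond the three standard token counts are copied into `additional_usage_values` on the spend-log metadata. For example (the cache-related field names are illustrative of what a provider like DeepSeek might return, not an exact schema):

```python
usage = {
    "prompt_tokens": 100,
    "completion_tokens": 25,
    "total_tokens": 125,
    "prompt_cache_hit_tokens": 80,   # illustrative provider-specific field
    "prompt_cache_miss_tokens": 20,  # illustrative provider-specific field
}

special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {k: v for k, v in usage.items() if k not in special_usage_fields}
# -> {'prompt_cache_hit_tokens': 80, 'prompt_cache_miss_tokens': 20}
```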


@ -161,10 +161,10 @@ class Router:
enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
retry_policy: Optional[
RetryPolicy
Union[RetryPolicy, dict]
] = None, # set custom retries for different exceptions
model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
model_group_retry_policy: Dict[
str, RetryPolicy
] = {}, # set custom retry policies based on model group
allowed_fails: Optional[
int
@ -263,7 +263,7 @@ class Router:
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
self.enable_tag_filtering = enable_tag_filtering
if self.set_verbose == True:
if self.set_verbose is True:
if debug_level == "INFO":
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
@ -454,11 +454,35 @@ class Router:
)
self.routing_strategy_args = routing_strategy_args
self.retry_policy: Optional[RetryPolicy] = retry_policy
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
if isinstance(retry_policy, dict):
self.retry_policy = RetryPolicy(**retry_policy)
elif isinstance(retry_policy, RetryPolicy):
self.retry_policy = retry_policy
verbose_router_logger.info(
"\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format(
self.retry_policy.model_dump(exclude_none=True)
)
)
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
if allowed_fails_policy is not None:
if isinstance(allowed_fails_policy, dict):
self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy)
elif isinstance(allowed_fails_policy, AllowedFailsPolicy):
self.allowed_fails_policy = allowed_fails_policy
verbose_router_logger.info(
"\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format(
self.allowed_fails_policy.model_dump(exclude_none=True)
)
)
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
@ -3003,6 +3027,13 @@ class Router:
model_group = kwargs.get("model")
num_retries = kwargs.pop("num_retries")
## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking
_metadata: dict = kwargs.get("metadata") or {}
if "model_group" in _metadata and isinstance(_metadata["model_group"], str):
model_list = self.get_model_list(model_name=_metadata["model_group"])
if model_list is not None:
_metadata.update({"model_group_size": len(model_list)})
verbose_router_logger.debug(
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
)
@ -3165,6 +3196,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
@ -3173,6 +3205,7 @@ class Router:
content_policy_fallbacks = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
@ -3324,6 +3357,9 @@ class Router:
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
mock_testing_rate_limit_error = kwargs.pop(
"mock_testing_rate_limit_error", None
)
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop(
@ -3332,9 +3368,22 @@ class Router:
content_policy_fallbacks = kwargs.pop(
"content_policy_fallbacks", self.content_policy_fallbacks
)
model_group = kwargs.get("model")
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
if (
mock_testing_rate_limit_error is not None
and mock_testing_rate_limit_error is True
):
verbose_router_logger.info(
"litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError."
)
raise litellm.RateLimitError(
model=model_group,
llm_provider="",
message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.",
)
response = original_function(*args, **kwargs)
return response
except Exception as e:
@ -3571,17 +3620,26 @@ class Router:
) # don't change existing ttl
def _is_cooldown_required(
self, exception_status: Union[str, int], exception_str: Optional[str] = None
):
self,
model_id: str,
exception_status: Union[str, int],
exception_str: Optional[str] = None,
) -> bool:
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
model_id (str) The id of the model in the model list
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
## BASE CASE - single deployment
model_group = self.get_model_group(id=model_id)
if model_group is not None and len(model_group) == 1:
return False
try:
ignored_strings = ["APIConnectionError"]
if (
@ -3677,7 +3735,9 @@ class Router:
if (
self._is_cooldown_required(
exception_status=exception_status, exception_str=str(original_exception)
model_id=deployment,
exception_status=exception_status,
exception_str=str(original_exception),
)
is False
):
@ -3690,7 +3750,9 @@ class Router:
exception=original_exception,
)
allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails
allowed_fails = (
_allowed_fails if _allowed_fails is not None else self.allowed_fails
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
@ -4298,6 +4360,18 @@ class Router:
return model
return None
def get_model_group(self, id: str) -> Optional[List]:
"""
Return list of all models in the same model group as that model id
"""
model_info = self.get_model_info(id=id)
if model_info is None:
return None
model_name = model_info["model_name"]
return self.get_model_list(model_name=model_name)
def _set_model_group_info(
self, model_group: str, user_facing_model_group_name: str
) -> Optional[ModelGroupInfo]:
@ -5528,18 +5602,19 @@ class Router:
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy = self.retry_policy
retry_policy: Optional[RetryPolicy] = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
retry_policy = self.model_group_retry_policy.get(model_group, None)
retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
if (
isinstance(exception, litellm.BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
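With the coercion added above, the proxy's config.yaml values and the SDK object are interchangeable: a plain dict passed as `retry_policy` (or `allowed_fails_policy`) is converted into the corresponding pydantic model at init time. A minimal sketch (the model entry and retry counts are illustrative):

```python
from litellm import Router
from litellm.types.router import RetryPolicy

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "Hello world"},
    }
]

# Dict form, e.g. parsed from config.yaml router_settings ...
router = Router(
    model_list=model_list,
    retry_policy={"BadRequestErrorRetries": 3, "ContentPolicyViolationErrorRetries": 4},
)

# ... behaves the same as passing the typed object directly.
router = Router(
    model_list=model_list,
    retry_policy=RetryPolicy(BadRequestErrorRetries=3, ContentPolicyViolationErrorRetries=4),
)
```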


@ -29,6 +29,27 @@ def _is_base64(s):
return False
def str_to_bool(value: str) -> Optional[bool]:
"""
Converts a string to a boolean if it's a recognized boolean string.
Returns None if the string is not a recognized boolean value.
:param value: The string to be checked.
:return: True or False if the string is a recognized boolean, otherwise None.
"""
true_values = {"true"}
false_values = {"false"}
value_lower = value.strip().lower()
if value_lower in true_values:
return True
elif value_lower in false_values:
return False
else:
return None
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
@ -257,17 +278,12 @@ def get_secret(
return secret
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except Exception:
if default_value is not None:
return default_value
secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance(
secret_value_as_bool, bool
):
return secret_value_as_bool
else:
return secret
except Exception as e:
if default_value is not None:
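Behavior of the new helper, worked from its definition above: only the literal strings "true" and "false" (case-insensitive, surrounding whitespace ignored) are converted, and anything else returns None so `get_secret` falls through to returning the raw string. The import path below is assumed from the commit message's `secret_managers/main.py`:

```python
from litellm.secret_managers.main import str_to_bool

assert str_to_bool("True") is True
assert str_to_bool(" false ") is False
assert str_to_bool("1") is None                   # not a recognized boolean string
assert str_to_bool("os.environ/MY_KEY") is None   # get_secret returns the raw value instead
```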


@ -54,6 +54,7 @@ VERTEX_MODELS_TO_NOT_TEST = [
"gemini-flash-experimental",
"gemini-1.5-flash-exp-0827",
"gemini-pro-flash",
"gemini-1.5-flash-exp-0827",
]


@ -38,6 +38,8 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
litellm.num_retries = 0
class CompletionCustomHandler(
CustomLogger
@ -401,7 +403,7 @@ async def test_async_chat_azure():
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
router = Router(model_list=model_list, num_retries=0) # type: ignore
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@ -413,7 +415,7 @@ async def test_async_chat_azure():
) # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list) # type: ignore
router2 = Router(model_list=model_list, num_retries=0) # type: ignore
response = await router2.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
@ -443,7 +445,7 @@ async def test_async_chat_azure():
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:
response = await router3.acompletion(
model="gpt-3.5-turbo",
@ -505,7 +507,7 @@ async def test_async_embedding_azure():
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
router3 = Router(model_list=model_list, num_retries=0) # type: ignore
try:
response = await router3.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
@ -678,22 +680,21 @@ async def test_rate_limit_error_callback():
pass
with patch.object(
customHandler, "log_model_group_rate_limit_error", new=MagicMock()
customHandler, "log_model_group_rate_limit_error", new=AsyncMock()
) as mock_client:
print(
f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
)
for _ in range(3):
try:
_ = await router.acompletion(
model="my-test-gpt",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
litellm_logging_obj=litellm_logging_obj,
)
except (litellm.RateLimitError, ValueError):
pass
try:
_ = await router.acompletion(
model="my-test-gpt",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
litellm_logging_obj=litellm_logging_obj,
)
except (litellm.RateLimitError, ValueError):
pass
await asyncio.sleep(3)
mock_client.assert_called_once()


@ -23,8 +23,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.proxy_server import chat_completion
@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience):
raise e
from unittest.mock import AsyncMock
import pytest
@ -844,3 +843,148 @@ async def test_team_cache_update_called():
await asyncio.sleep(3)
mock_call_cache.assert_awaited_once()
@pytest.fixture
def public_jwt_key():
import json
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
# Generate a private / public key pair using RSA algorithm
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
# Get private key in PEM format
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
# Get public key in PEM format
public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
public_key_obj = serialization.load_pem_public_key(
public_key, backend=default_backend()
)
# Convert RSA public key object to JWK (JSON Web Key)
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
return {"private_key": private_key, "public_jwk": public_jwk}
def mock_user_object(*args, **kwargs):
print("Args: {}".format(args))
print("kwargs: {}".format(kwargs))
assert kwargs["user_id_upsert"] is True
@pytest.mark.parametrize(
"user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)]
)
@pytest.mark.asyncio
async def test_allow_access_by_email(public_jwt_key, user_email, should_work):
"""
Allow anyone with an `@xyz.com` email to make a request to the proxy.
Relevant issue: https://github.com/BerriAI/litellm/issues/5605
"""
import jwt
from starlette.datastructures import URL
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
from litellm.proxy.proxy_server import user_api_key_auth
public_jwk = public_jwt_key["public_jwk"]
private_key = public_jwt_key["private_key"]
# set cache
cache = DualCache()
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
jwt_handler = JWTHandler()
jwt_handler.user_api_key_cache = cache
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
user_email_jwt_field="email",
user_allowed_email_domain="berri.ai",
user_id_upsert=True,
)
# VALID TOKEN
## GENERATE A TOKEN
# Assuming the current time is in UTC
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
team_id = f"team123_{uuid.uuid4()}"
payload = {
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": "litellm-proxy",
"email": user_email,
}
# Generate the JWT token
# But before, you should convert bytes to string
private_key_str = private_key.decode("utf-8")
## team token
token = jwt.encode(payload, private_key_str, algorithm="RS256")
## VERIFY IT WORKS
# Expect the call to succeed
response = await jwt_handler.auth_jwt(token=token)
assert response is not None # Adjust this based on your actual response check
## RUN IT THROUGH USER API KEY AUTH
bearer_token = "Bearer " + token
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in
setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
setattr(litellm.proxy.proxy_server, "prisma_client", {})
# AsyncMock(
# return_value=LiteLLM_UserTable(
# spend=0, user_id=user_email, max_budget=None, user_email=user_email
# )
# ),
with patch.object(
litellm.proxy.auth.user_api_key_auth,
"get_user_object",
side_effect=mock_user_object,
) as mock_client:
if should_work:
# Expect the call to succeed
result = await user_api_key_auth(request=request, api_key=bearer_token)
assert result is not None # Adjust this based on your actual response check
else:
# Expect the call to fail
with pytest.raises(
Exception
): # Replace with the actual exception raised on failure
resp = await user_api_key_auth(request=request, api_key=bearer_token)
print(resp)


@ -1,16 +1,17 @@
import sys
import os
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion
import litellm
from litellm import completion
litellm.failure_callback = ["lunary"]
litellm.success_callback = ["lunary"]
litellm.set_verbose = True
def test_lunary_logging():
try:
response = completion(
@ -24,6 +25,7 @@ def test_lunary_logging():
except Exception as e:
print(e)
test_lunary_logging()
@ -37,8 +39,6 @@ def test_lunary_template():
except Exception as e:
print(e)
test_lunary_template()
def test_lunary_logging_with_metadata():
try:
@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata():
metadata={
"run_name": "litellmRUN",
"project_name": "litellm-completion",
"tags": ["tag1", "tag2"]
"tags": ["tag1", "tag2"],
},
)
print(response)
except Exception as e:
print(e)
test_lunary_logging_with_metadata()
def test_lunary_with_tools():
import litellm
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
messages = [
{
"role": "user",
"content": "What's the weather like in San Francisco, Tokyo, and Paris?",
}
]
tools = [
{
"type": "function",
@ -90,13 +94,11 @@ def test_lunary_with_tools():
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)
response_message = response.choices[0].message
print("\nLLM Response:\n", response.choices[0].message)
test_lunary_with_tools()
def test_lunary_logging_with_streaming_and_metadata():
try:
response = completion(
@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata():
continue
except Exception as e:
print(e)
test_lunary_logging_with_streaming_and_metadata()


@ -11,8 +11,11 @@ import litellm
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import (
_get_dynamic_logging_metadata,
add_litellm_data_to_request,
)
from litellm.types.utils import SupportedCacheControls
@ -204,3 +207,87 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request(
# assert (
# new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"]
# )
@pytest.mark.parametrize(
"callback_vars",
[
{
"langfuse_host": "https://us.cloud.langfuse.com",
"langfuse_public_key": "pk-lf-9636b7a6-c066",
"langfuse_secret_key": "sk-lf-7cc8b620",
},
{
"langfuse_host": "os.environ/LANGFUSE_HOST_TEMP",
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP",
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP",
},
],
)
def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
user_api_key_dict = UserAPIKeyAuth(
token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
key_name="sk-...63Fg",
key_alias=None,
spend=0.000111,
max_budget=None,
expires=None,
models=[],
aliases={},
config={},
user_id=None,
team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898",
max_parallel_requests=None,
metadata={
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": callback_vars,
}
]
},
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
budget_reset_at=None,
allowed_cache_controls=[],
permissions={},
model_spend={},
model_max_budget={},
soft_budget_cooldown=False,
litellm_budget_table=None,
org_id=None,
team_spend=0.000132,
team_alias=None,
team_tpm_limit=None,
team_rpm_limit=None,
team_max_budget=None,
team_models=[],
team_blocked=False,
soft_budget=None,
team_model_aliases=None,
team_member_spend=None,
team_member=None,
team_metadata={},
end_user_id=None,
end_user_tpm_limit=None,
end_user_rpm_limit=None,
end_user_max_budget=None,
last_refreshed_at=1726101560.967527,
api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1",
user_role=LitellmUserRoles.INTERNAL_USER,
allowed_model_region=None,
parent_otel_span=None,
rpm_limit_per_model=None,
tpm_limit_per_model=None,
)
callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
assert callbacks is not None
for var in callbacks.callback_vars.values():
assert "os.environ" not in var


@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error():
except litellm.APIConnectionError as e:
assert (
Router()._is_cooldown_required(
exception_status=e.code, exception_str=str(e)
model_id="", exception_status=e.code, exception_str=str(e)
)
is False
)
@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
"litellm_params": {
"model": "openai/text-embedding-ada-002",
},
}
},
{
"model_name": "text-embedding-ada-002",
"litellm_params": {
"model": "openai/text-embedding-ada-002",
},
},
]
)


@ -21,6 +21,7 @@ import openai
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict
@pytest.mark.asyncio
@ -112,3 +113,40 @@ async def test_dynamic_cooldowns():
assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"]
assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0
@pytest.mark.parametrize("num_deployments", [1, 2])
def test_single_deployment_no_cooldowns(num_deployments):
"""
Do not cooldown on single deployment.
Cooldown on multiple deployments.
"""
model_list = []
for i in range(num_deployments):
model = DeploymentTypedDict(
model_name="gpt-3.5-turbo",
litellm_params=LiteLLMParamsTypedDict(
model="gpt-3.5-turbo",
),
)
model_list.append(model)
router = Router(model_list=model_list, allowed_fails=0, num_retries=0)
with patch.object(
router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
if num_deployments == 1:
mock_client.assert_not_called()
else:
mock_client.assert_called_once()


@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type):
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": _api_key,
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(model_list=model_list, allowed_fails=3)


@ -17,6 +17,8 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
@ -459,3 +461,138 @@ async def test_router_completion_streaming():
- Unit test for sync 'pre_call_checks'
- Unit test for async 'async_pre_call_checks'
"""
@pytest.mark.asyncio
async def test_router_caching_ttl():
"""
Confirm caching ttl's work as expected.
Relevant issue: https://github.com/BerriAI/litellm/issues/5609
"""
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
"mock_response": "Hello world",
},
"model_info": {"id": 1},
}
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=os.getenv("REDIS_PORT"),
)
assert router.cache.redis_cache is not None
increment_cache_kwargs = {}
with patch.object(
router.cache.redis_cache,
"async_increment",
new=AsyncMock(),
) as mock_client:
await router.acompletion(model=model, messages=messages)
mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")
increment_cache_kwargs = {
"key": mock_client.call_args.args[0],
"value": mock_client.call_args.args[1],
"ttl": mock_client.call_args.kwargs["ttl"],
}
assert mock_client.call_args.kwargs["ttl"] == 60
## call redis async increment and check if ttl correctly set
await router.cache.redis_cache.async_increment(**increment_cache_kwargs)
_redis_client = router.cache.redis_cache.init_async_client()
async with _redis_client as redis_client:
current_ttl = await redis_client.ttl(increment_cache_kwargs["key"])
assert current_ttl >= 0
print(f"current_ttl: {current_ttl}")
def test_router_caching_ttl_sync():
"""
Confirm caching ttl's work as expected.
Relevant issue: https://github.com/BerriAI/litellm/issues/5609
"""
messages = [
{"role": "user", "content": "Hello, can you generate a 500 words poem?"}
]
model = "azure-model"
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"tpm": 1440,
"mock_response": "Hello world",
},
"model_info": {"id": 1},
}
]
router = Router(
model_list=model_list,
routing_strategy="usage-based-routing-v2",
set_verbose=False,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=os.getenv("REDIS_PORT"),
)
assert router.cache.redis_cache is not None
increment_cache_kwargs = {}
with patch.object(
router.cache.redis_cache,
"increment_cache",
new=MagicMock(),
) as mock_client:
router.completion(model=model, messages=messages)
print(mock_client.call_args_list)
mock_client.assert_called()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")
increment_cache_kwargs = {
"key": mock_client.call_args.args[0],
"value": mock_client.call_args.args[1],
"ttl": mock_client.call_args.kwargs["ttl"],
}
assert mock_client.call_args.kwargs["ttl"] == 60
## call redis async increment and check if ttl correctly set
router.cache.redis_cache.increment_cache(**increment_cache_kwargs)
_redis_client = router.cache.redis_cache.redis_client
current_ttl = _redis_client.ttl(increment_cache_kwargs["key"])
assert current_ttl >= 0
print(f"current_ttl: {current_ttl}")


@ -4,7 +4,6 @@ import time
import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from traceloop.sdk import Traceloop
import litellm
@ -13,6 +12,8 @@ sys.path.insert(0, os.path.abspath("../.."))
@pytest.fixture()
def exporter():
from traceloop.sdk import Traceloop
exporter = InMemorySpanExporter()
Traceloop.init(
app_name="test_litellm",