diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a33473b72..d429bc6b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: local hooks: - - id: mypy - name: mypy - entry: python3 -m mypy --ignore-missing-imports - language: system - types: [python] - files: ^litellm/ + # - id: mypy + # name: mypy + # entry: python3 -m mypy --ignore-missing-imports + # language: system + # types: [python] + # files: ^litellm/ - id: isort name: isort entry: isort diff --git a/docs/my-website/docs/proxy/token_auth.md b/docs/my-website/docs/proxy/token_auth.md index 659cc6edf..049ea0f98 100644 --- a/docs/my-website/docs/proxy/token_auth.md +++ b/docs/my-website/docs/proxy/token_auth.md @@ -243,3 +243,17 @@ curl --location 'http://0.0.0.0:4000/team/unblock' \ }' ``` + +## Advanced - Upsert Users + Allowed Email Domains + +Allow users who belong to a specific email domain, automatic access to the proxy. + +```yaml +general_settings: + master_key: sk-1234 + enable_jwt_auth: True + litellm_jwtauth: + user_email_jwt_field: "email" # 👈 checks 'email' field in jwt payload + user_allowed_email_domain: "my-co.com" # allows user@my-co.com to call proxy + user_id_upsert: true # 👈 upserts the user to db, if valid email but not in db +``` \ No newline at end of file diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index c7c6c3c97..87925516a 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -1038,6 +1038,12 @@ print(f"response: {response}") - Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved - Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment +[**See All Exception Types**](https://github.com/BerriAI/litellm/blob/ccda616f2f881375d4e8586c76fe4662909a7d22/litellm/types/router.py#L436) + + + + + Example: ```python @@ -1101,6 +1107,24 @@ response = await router.acompletion( ) ``` + + + +```yaml +router_settings: + retry_policy: { + "BadRequestErrorRetries": 3, + "ContentPolicyViolationErrorRetries": 4 + } + allowed_fails_policy: { + "ContentPolicyViolationErrorAllowedFails": 1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment + "RateLimitErrorAllowedFails": 100 # Allow 100 RateLimitErrors before cooling down a deployment + } +``` + + + + ### Fallbacks diff --git a/litellm/caching.py b/litellm/caching.py index 5add0cd8e..0a9fef417 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -304,40 +304,25 @@ class RedisCache(BaseCache): f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" ) - def increment_cache(self, key, value: int, **kwargs) -> int: + def increment_cache( + self, key, value: int, ttl: Optional[float] = None, **kwargs + ) -> int: _redis_client = self.redis_client start_time = time.time() try: result = _redis_client.incr(name=key, amount=value) - ## LOGGING ## - end_time = time.time() - _duration = end_time - start_time - asyncio.create_task( - self.service_logger_obj.service_success_hook( - service=ServiceTypes.REDIS, - duration=_duration, - call_type="increment_cache", - start_time=start_time, - end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), - ) - ) + + if ttl is not None: + # check if key already has ttl, if not -> set ttl + current_ttl = _redis_client.ttl(key) + if current_ttl == -1: + # Key has no expiration + _redis_client.expire(key, ttl) return result except Exception as e: ## LOGGING ## end_time = time.time() 
_duration = end_time - start_time - asyncio.create_task( - self.service_logger_obj.async_service_failure_hook( - service=ServiceTypes.REDIS, - duration=_duration, - error=e, - call_type="increment_cache", - start_time=start_time, - end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), - ) - ) verbose_logger.error( "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s", str(e), @@ -606,12 +591,22 @@ class RedisCache(BaseCache): if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: await self.flush_cache_buffer() # logging done in here - async def async_increment(self, key, value: float, **kwargs) -> float: + async def async_increment( + self, key, value: float, ttl: Optional[float] = None, **kwargs + ) -> float: _redis_client = self.init_async_client() start_time = time.time() try: async with _redis_client as redis_client: result = await redis_client.incrbyfloat(name=key, amount=value) + + if ttl is not None: + # check if key already has ttl, if not -> set ttl + current_ttl = await redis_client.ttl(key) + if current_ttl == -1: + # Key has no expiration + await redis_client.expire(key, ttl) + ## LOGGING ## end_time = time.time() _duration = end_time - start_time diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 0d3e59db5..9528b6fbb 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1609,15 +1609,24 @@ class Logging: """ from litellm.types.router import RouterErrors + litellm_params: dict = self.model_call_details.get("litellm_params") or {} + metadata = litellm_params.get("metadata") or {} + + ## BASE CASE ## check if rate limit error for model group size 1 + is_base_case = False + if metadata.get("model_group_size") is not None: + model_group_size = metadata.get("model_group_size") + if isinstance(model_group_size, int) and model_group_size == 1: + is_base_case = True ## check if special error ## - if RouterErrors.no_deployments_available.value not in str(exception): + if ( + RouterErrors.no_deployments_available.value not in str(exception) + and is_base_case is False + ): return ## get original model group ## - litellm_params: dict = self.model_call_details.get("litellm_params") or {} - metadata = litellm_params.get("metadata") or {} - model_group = metadata.get("model_group") or None for callback in litellm._async_failure_callback: if isinstance(callback, CustomLogger): # custom logger class diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html deleted file mode 100644 index 34d1b613d..000000000 --- a/litellm/proxy/_experimental/out/404.html +++ /dev/null @@ -1 +0,0 @@ -404: This page could not be found.LiteLLM Dashboard
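The `increment_cache` / `async_increment` changes above only set an expiry when the key does not already have one, so every increment inside a window reuses the TTL started by the first increment. A minimal standalone sketch of that pattern, assuming a locally reachable Redis and the `redis` client; the key name and 60-second window are illustrative, not values from this diff:

```python
import redis


def increment_with_window(
    client: redis.Redis, key: str, amount: int = 1, window_seconds: int = 60
) -> int:
    """Increment a counter, starting its expiry window only on the first write."""
    # INCR creates the key (at 0) if it is missing and never touches the TTL itself.
    new_value = client.incr(name=key, amount=amount)

    # ttl() returns -1 for a key that exists without an expiration
    # (and -2 for a missing key), so only the first increment sets the window.
    if client.ttl(key) == -1:
        client.expire(key, window_seconds)

    return new_value


if __name__ == "__main__":
    r = redis.Redis(host="localhost", port=6379)  # assumed local instance
    print(increment_with_window(r, "illustrative:rpm:my-model-group"))
```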

\ No newline at end of file diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html deleted file mode 100644 index 07e68f30e..000000000 --- a/litellm/proxy/_experimental/out/model_hub.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index abb658918..000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index da98b6a10..662d6d835 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -386,6 +386,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`. - team_allowed_routes: list of allowed routes for proxy team roles. - user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees. + - user_email_jwt_field: The field in the JWT token that stores the user email (maps to `LiteLLMUserTable`). Use this for internal employees. + - user_allowed_email_subdomain: If specified, only emails from specified subdomain will be allowed to access proxy. - end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers. - public_key_ttl: Default - 600s. TTL for caching public JWT keys. @@ -417,6 +419,8 @@ class LiteLLM_JWTAuth(LiteLLMBase): ) org_id_jwt_field: Optional[str] = None user_id_jwt_field: Optional[str] = None + user_email_jwt_field: Optional[str] = None + user_allowed_email_domain: Optional[str] = None user_id_upsert: bool = Field( default=False, description="If user doesn't exist, upsert them into the db." ) @@ -1690,6 +1694,9 @@ class SpendLogsMetadata(TypedDict): Specific metadata k,v pairs logged to spendlogs for easier cost tracking """ + additional_usage_values: Optional[ + dict + ] # covers provider-specific usage information - e.g. 
prompt caching user_api_key: Optional[str] user_api_key_alias: Optional[str] user_api_key_team_id: Optional[str] diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index f8618781f..b39064ae6 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -78,6 +78,19 @@ class JWTHandler: return False return True + def is_enforced_email_domain(self) -> bool: + """ + Returns: + - True: if 'user_allowed_email_domain' is set + - False: if 'user_allowed_email_domain' is None + """ + + if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance( + self.litellm_jwtauth.user_allowed_email_domain, str + ): + return True + return False + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: if self.litellm_jwtauth.team_id_jwt_field is not None: @@ -90,12 +103,14 @@ class JWTHandler: team_id = default_value return team_id - def is_upsert_user_id(self) -> bool: + def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool: """ Returns: - - True: if 'user_id_upsert' is set + - True: if 'user_id_upsert' is set AND valid_user_email is not False - False: if not """ + if valid_user_email is False: + return False return self.litellm_jwtauth.user_id_upsert def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: @@ -103,11 +118,23 @@ class JWTHandler: if self.litellm_jwtauth.user_id_jwt_field is not None: user_id = token[self.litellm_jwtauth.user_id_jwt_field] else: - user_id = None + user_id = default_value except KeyError: user_id = default_value return user_id + def get_user_email( + self, token: dict, default_value: Optional[str] + ) -> Optional[str]: + try: + if self.litellm_jwtauth.user_email_jwt_field is not None: + user_email = token[self.litellm_jwtauth.user_email_jwt_field] + else: + user_email = None + except KeyError: + user_email = default_value + return user_email + def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: try: if self.litellm_jwtauth.org_id_jwt_field is not None: @@ -183,6 +210,16 @@ class JWTHandler: return public_key + def is_allowed_domain(self, user_email: str) -> bool: + if self.litellm_jwtauth.user_allowed_email_domain is None: + return True + + email_domain = user_email.split("@")[-1] # Extract domain from email + if email_domain == self.litellm_jwtauth.user_allowed_email_domain: + return True + else: + return False + async def auth_jwt(self, token: str) -> dict: # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index deee81ffd..114f27d44 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -250,6 +250,7 @@ async def user_api_key_auth( raise Exception( f"Admin not allowed to access this route. 
Route={route}, Allowed Routes={actual_routes}" ) + # get team id team_id = jwt_handler.get_team_id( token=jwt_valid_token, default_value=None @@ -296,10 +297,30 @@ async def user_api_key_auth( parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) + # [OPTIONAL] allowed user email domains + valid_user_email: Optional[bool] = None + user_email: Optional[str] = None + if jwt_handler.is_enforced_email_domain(): + """ + if 'allowed_email_subdomains' is set, + + - checks if token contains 'email' field + - checks if 'email' is from an allowed domain + """ + user_email = jwt_handler.get_user_email( + token=jwt_valid_token, default_value=None + ) + if user_email is None: + valid_user_email = False + else: + valid_user_email = jwt_handler.is_allowed_domain( + user_email=user_email + ) + # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` user_object = None user_id = jwt_handler.get_user_id( - token=jwt_valid_token, default_value=None + token=jwt_valid_token, default_value=user_email ) if user_id is not None: # get the user object @@ -307,11 +328,12 @@ async def user_api_key_auth( user_id=user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, - user_id_upsert=jwt_handler.is_upsert_user_id(), + user_id_upsert=jwt_handler.is_upsert_user_id( + valid_user_email=valid_user_email + ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) - # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` end_user_object = None end_user_id = jwt_handler.get_end_user_id( @@ -802,7 +824,7 @@ async def user_api_key_auth( # collect information for alerting # #################################### - user_email: Optional[str] = None + user_email = None # Check if the token has any user id information if user_obj is not None: user_email = user_obj.user_email diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 890c576c9..4c6172a4d 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -107,7 +107,16 @@ def _get_dynamic_logging_metadata( user_api_key_dict: UserAPIKeyAuth, ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - if user_api_key_dict.team_metadata is not None: + if ( + user_api_key_dict.metadata is not None + and "logging" in user_api_key_dict.metadata + ): + for item in user_api_key_dict.metadata["logging"]: + callback_settings_obj = convert_key_logging_metadata_to_callback( + data=AddTeamCallback(**item), + team_callback_settings_obj=callback_settings_obj, + ) + elif user_api_key_dict.team_metadata is not None: team_metadata = user_api_key_dict.team_metadata if "callback_settings" in team_metadata: callback_settings = team_metadata.get("callback_settings", None) or {} @@ -124,15 +133,7 @@ def _get_dynamic_logging_metadata( } } """ - elif ( - user_api_key_dict.metadata is not None - and "logging" in user_api_key_dict.metadata - ): - for item in user_api_key_dict.metadata["logging"]: - callback_settings_obj = convert_key_logging_metadata_to_callback( - data=AddTeamCallback(**item), - team_callback_settings_obj=callback_settings_obj, - ) + return callback_settings_obj diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index a1a0b9733..bdeef92cc 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -84,6 +84,7 @@ def 
get_logging_payload( user_api_key_team_alias=None, spend_logs_metadata=None, requester_ip_address=None, + additional_usage_values=None, ) if isinstance(metadata, dict): verbose_proxy_logger.debug( @@ -100,6 +101,13 @@ def get_logging_payload( } ) + special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"] + additional_usage_values = {} + for k, v in usage.items(): + if k not in special_usage_fields: + additional_usage_values.update({k: v}) + clean_metadata["additional_usage_values"] = additional_usage_values + if litellm.cache is not None: cache_key = litellm.cache.get_cache_key(**kwargs) else: diff --git a/litellm/router.py b/litellm/router.py index 5a01f4f39..c187474f1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -161,10 +161,10 @@ class Router: enable_tag_filtering: bool = False, retry_after: int = 0, # min time to wait before retrying a failed request retry_policy: Optional[ - RetryPolicy + Union[RetryPolicy, dict] ] = None, # set custom retries for different exceptions - model_group_retry_policy: Optional[ - Dict[str, RetryPolicy] + model_group_retry_policy: Dict[ + str, RetryPolicy ] = {}, # set custom retry policies based on model group allowed_fails: Optional[ int @@ -263,7 +263,7 @@ class Router: self.debug_level = debug_level self.enable_pre_call_checks = enable_pre_call_checks self.enable_tag_filtering = enable_tag_filtering - if self.set_verbose == True: + if self.set_verbose is True: if debug_level == "INFO": verbose_router_logger.setLevel(logging.INFO) elif debug_level == "DEBUG": @@ -454,11 +454,35 @@ class Router: ) self.routing_strategy_args = routing_strategy_args - self.retry_policy: Optional[RetryPolicy] = retry_policy + self.retry_policy: Optional[RetryPolicy] = None + if retry_policy is not None: + if isinstance(retry_policy, dict): + self.retry_policy = RetryPolicy(**retry_policy) + elif isinstance(retry_policy, RetryPolicy): + self.retry_policy = retry_policy + verbose_router_logger.info( + "\033[32mRouter Custom Retry Policy Set:\n{}\033[0m".format( + self.retry_policy.model_dump(exclude_none=True) + ) + ) + self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( model_group_retry_policy ) - self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy + + self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None + if allowed_fails_policy is not None: + if isinstance(allowed_fails_policy, dict): + self.allowed_fails_policy = AllowedFailsPolicy(**allowed_fails_policy) + elif isinstance(allowed_fails_policy, AllowedFailsPolicy): + self.allowed_fails_policy = allowed_fails_policy + + verbose_router_logger.info( + "\033[32mRouter Custom Allowed Fails Policy Set:\n{}\033[0m".format( + self.allowed_fails_policy.model_dump(exclude_none=True) + ) + ) + self.alerting_config: Optional[AlertingConfig] = alerting_config if self.alerting_config is not None: self._initialize_alerting() @@ -3003,6 +3027,13 @@ class Router: model_group = kwargs.get("model") num_retries = kwargs.pop("num_retries") + ## ADD MODEL GROUP SIZE TO METADATA - used for model_group_rate_limit_error tracking + _metadata: dict = kwargs.get("metadata") or {} + if "model_group" in _metadata and isinstance(_metadata["model_group"], str): + model_list = self.get_model_list(model_name=_metadata["model_group"]) + if model_list is not None: + _metadata.update({"model_group_size": len(model_list)}) + verbose_router_logger.debug( f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}" ) @@ -3165,6 +3196,7 @@ 
class Router: If it fails after num_retries, fall back to another model group """ mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) + model_group = kwargs.get("model") fallbacks = kwargs.get("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.get( @@ -3173,6 +3205,7 @@ class Router: content_policy_fallbacks = kwargs.get( "content_policy_fallbacks", self.content_policy_fallbacks ) + try: if mock_testing_fallbacks is not None and mock_testing_fallbacks == True: raise Exception( @@ -3324,6 +3357,9 @@ class Router: f"Inside function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") + mock_testing_rate_limit_error = kwargs.pop( + "mock_testing_rate_limit_error", None + ) num_retries = kwargs.pop("num_retries") fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( @@ -3332,9 +3368,22 @@ class Router: content_policy_fallbacks = kwargs.pop( "content_policy_fallbacks", self.content_policy_fallbacks ) + model_group = kwargs.get("model") try: # if the function call is successful, no exception will be raised and we'll break out of the loop + if ( + mock_testing_rate_limit_error is not None + and mock_testing_rate_limit_error is True + ): + verbose_router_logger.info( + "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError." + ) + raise litellm.RateLimitError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", + ) response = original_function(*args, **kwargs) return response except Exception as e: @@ -3571,17 +3620,26 @@ class Router: ) # don't change existing ttl def _is_cooldown_required( - self, exception_status: Union[str, int], exception_str: Optional[str] = None - ): + self, + model_id: str, + exception_status: Union[str, int], + exception_str: Optional[str] = None, + ) -> bool: """ A function to determine if a cooldown is required based on the exception status. Parameters: + model_id (str) The id of the model in the model list exception_status (Union[str, int]): The status of the exception. Returns: bool: True if a cooldown is required, False otherwise. 
""" + ## BASE CASE - single deployment + model_group = self.get_model_group(id=model_id) + if model_group is not None and len(model_group) == 1: + return False + try: ignored_strings = ["APIConnectionError"] if ( @@ -3677,7 +3735,9 @@ class Router: if ( self._is_cooldown_required( - exception_status=exception_status, exception_str=str(original_exception) + model_id=deployment, + exception_status=exception_status, + exception_str=str(original_exception), ) is False ): @@ -3690,7 +3750,9 @@ class Router: exception=original_exception, ) - allowed_fails = _allowed_fails if _allowed_fails is not None else self.allowed_fails + allowed_fails = ( + _allowed_fails if _allowed_fails is not None else self.allowed_fails + ) dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") @@ -4298,6 +4360,18 @@ class Router: return model return None + def get_model_group(self, id: str) -> Optional[List]: + """ + Return list of all models in the same model group as that model id + """ + + model_info = self.get_model_info(id=id) + if model_info is None: + return None + + model_name = model_info["model_name"] + return self.get_model_list(model_name=model_name) + def _set_model_group_info( self, model_group: str, user_facing_model_group_name: str ) -> Optional[ModelGroupInfo]: @@ -5528,18 +5602,19 @@ class Router: ContentPolicyViolationErrorRetries: Optional[int] = None """ # if we can find the exception then in the retry policy -> return the number of retries - retry_policy = self.retry_policy + retry_policy: Optional[RetryPolicy] = self.retry_policy if ( self.model_group_retry_policy is not None and model_group is not None and model_group in self.model_group_retry_policy ): - retry_policy = self.model_group_retry_policy.get(model_group, None) + retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore if retry_policy is None: return None if isinstance(retry_policy, dict): retry_policy = RetryPolicy(**retry_policy) + if ( isinstance(exception, litellm.BadRequestError) and retry_policy.BadRequestErrorRetries is not None diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index e136654c1..5d1f72cf7 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -29,6 +29,27 @@ def _is_base64(s): return False +def str_to_bool(value: str) -> Optional[bool]: + """ + Converts a string to a boolean if it's a recognized boolean string. + Returns None if the string is not a recognized boolean value. + + :param value: The string to be checked. + :return: True or False if the string is a recognized boolean, otherwise None. 
+ """ + true_values = {"true"} + false_values = {"false"} + + value_lower = value.strip().lower() + + if value_lower in true_values: + return True + elif value_lower in false_values: + return False + else: + return None + + def get_secret( secret_name: str, default_value: Optional[Union[str, bool]] = None, @@ -257,17 +278,12 @@ def get_secret( return secret else: secret = os.environ.get(secret_name) - try: - secret_value_as_bool = ( - ast.literal_eval(secret) if secret is not None else None - ) - if isinstance(secret_value_as_bool, bool): - return secret_value_as_bool - else: - return secret - except Exception: - if default_value is not None: - return default_value + secret_value_as_bool = str_to_bool(secret) if secret is not None else None + if secret_value_as_bool is not None and isinstance( + secret_value_as_bool, bool + ): + return secret_value_as_bool + else: return secret except Exception as e: if default_value is not None: diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index de2729db7..91ed7ea4a 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -54,6 +54,7 @@ VERTEX_MODELS_TO_NOT_TEST = [ "gemini-flash-experimental", "gemini-1.5-flash-exp-0827", "gemini-pro-flash", + "gemini-1.5-flash-exp-0827", ] diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py index 80fc096e7..6ffa97d89 100644 --- a/litellm/tests/test_custom_callback_router.py +++ b/litellm/tests/test_custom_callback_router.py @@ -38,6 +38,8 @@ from litellm.integrations.custom_logger import CustomLogger ## 1. router.completion() + router.embeddings() ## 2. proxy.completions + proxy.embeddings +litellm.num_retries = 0 + class CompletionCustomHandler( CustomLogger @@ -401,7 +403,7 @@ async def test_async_chat_azure(): "rpm": 1800, }, ] - router = Router(model_list=model_list) # type: ignore + router = Router(model_list=model_list, num_retries=0) # type: ignore response = await router.acompletion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], @@ -413,7 +415,7 @@ async def test_async_chat_azure(): ) # pre, post, success # streaming litellm.callbacks = [customHandler_streaming_azure_router] - router2 = Router(model_list=model_list) # type: ignore + router2 = Router(model_list=model_list, num_retries=0) # type: ignore response = await router2.acompletion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], @@ -443,7 +445,7 @@ async def test_async_chat_azure(): }, ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore + router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: response = await router3.acompletion( model="gpt-3.5-turbo", @@ -505,7 +507,7 @@ async def test_async_embedding_azure(): }, ] litellm.callbacks = [customHandler_failure] - router3 = Router(model_list=model_list) # type: ignore + router3 = Router(model_list=model_list, num_retries=0) # type: ignore try: response = await router3.aembedding( model="azure-embedding-model", input=["hello from litellm!"] @@ -678,22 +680,21 @@ async def test_rate_limit_error_callback(): pass with patch.object( - customHandler, "log_model_group_rate_limit_error", new=MagicMock() + customHandler, "log_model_group_rate_limit_error", new=AsyncMock() ) as mock_client: print( f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}" ) - 
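For reference on the `get_secret` change above: the old path ran `ast.literal_eval` on raw environment values, which coerces strings like "1" into non-boolean literals, while the new helper only recognizes a case-insensitive "true"/"false" and otherwise returns `None`, letting the raw string pass through. A quick worked example reproducing the same logic as the helper added above:

```python
from typing import Optional


def str_to_bool(value: str) -> Optional[bool]:
    # Same logic as the helper added above: only "true"/"false" (any casing,
    # surrounding whitespace ignored) are treated as booleans.
    true_values = {"true"}
    false_values = {"false"}
    value_lower = value.strip().lower()
    if value_lower in true_values:
        return True
    elif value_lower in false_values:
        return False
    return None


assert str_to_bool("True") is True
assert str_to_bool(" FALSE ") is False
assert str_to_bool("1") is None    # previously ast.literal_eval("1") gave 1, not a bool
assert str_to_bool("yes") is None  # unrecognized strings are left to the caller
```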
for _ in range(3): - try: - _ = await router.acompletion( - model="my-test-gpt", - messages=[{"role": "user", "content": "Hey, how's it going?"}], - litellm_logging_obj=litellm_logging_obj, - ) - except (litellm.RateLimitError, ValueError): - pass + try: + _ = await router.acompletion( + model="my-test-gpt", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + litellm_logging_obj=litellm_logging_obj, + ) + except (litellm.RateLimitError, ValueError): + pass await asyncio.sleep(3) mock_client.assert_called_once() diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index ddafdb933..51bf55c9c 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -23,8 +23,9 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import Request +import litellm from litellm.caching import DualCache -from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes +from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.team_endpoints import new_team from litellm.proxy.proxy_server import chat_completion @@ -816,8 +817,6 @@ async def test_allowed_routes_admin(prisma_client, audience): raise e -from unittest.mock import AsyncMock - import pytest @@ -844,3 +843,148 @@ async def test_team_cache_update_called(): await asyncio.sleep(3) mock_call_cache.assert_awaited_once() + + +@pytest.fixture +def public_jwt_key(): + import json + + import jwt + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + # Generate a private / public key pair using RSA algorithm + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + # Get private key in PEM format + private_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Get public key in PEM format + public_key = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + public_key_obj = serialization.load_pem_public_key( + public_key, backend=default_backend() + ) + + # Convert RSA public key object to JWK (JSON Web Key) + public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj)) + + return {"private_key": private_key, "public_jwk": public_jwk} + + +def mock_user_object(*args, **kwargs): + print("Args: {}".format(args)) + print("kwargs: {}".format(kwargs)) + assert kwargs["user_id_upsert"] is True + + +@pytest.mark.parametrize( + "user_email, should_work", [("ishaan@berri.ai", True), ("krrish@tassle.xyz", False)] +) +@pytest.mark.asyncio +async def test_allow_access_by_email(public_jwt_key, user_email, should_work): + """ + Allow anyone with an `@xyz.com` email make a request to the proxy. 
+ + Relevant issue: https://github.com/BerriAI/litellm/issues/5605 + """ + import jwt + from starlette.datastructures import URL + + from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth + from litellm.proxy.proxy_server import user_api_key_auth + + public_jwk = public_jwt_key["public_jwk"] + private_key = public_jwt_key["private_key"] + + # set cache + cache = DualCache() + + await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk]) + + jwt_handler = JWTHandler() + + jwt_handler.user_api_key_cache = cache + + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + user_email_jwt_field="email", + user_allowed_email_domain="berri.ai", + user_id_upsert=True, + ) + + # VALID TOKEN + ## GENERATE A TOKEN + # Assuming the current time is in UTC + expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp()) + + team_id = f"team123_{uuid.uuid4()}" + payload = { + "sub": "user123", + "exp": expiration_time, # set the token to expire in 10 minutes + "scope": "litellm_team", + "client_id": team_id, + "aud": "litellm-proxy", + "email": user_email, + } + + # Generate the JWT token + # But before, you should convert bytes to string + private_key_str = private_key.decode("utf-8") + + ## team token + token = jwt.encode(payload, private_key_str, algorithm="RS256") + + ## VERIFY IT WORKS + # Expect the call to succeed + response = await jwt_handler.auth_jwt(token=token) + assert response is not None # Adjust this based on your actual response check + + ## RUN IT THROUGH USER API KEY AUTH + bearer_token = "Bearer " + token + + request = Request(scope={"type": "http"}) + + request._url = URL(url="/chat/completions") + + ## 1. INITIAL TEAM CALL - should fail + # use generated key to auth in + setattr( + litellm.proxy.proxy_server, + "general_settings", + { + "enable_jwt_auth": True, + }, + ) + setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + setattr(litellm.proxy.proxy_server, "prisma_client", {}) + + # AsyncMock( + # return_value=LiteLLM_UserTable( + # spend=0, user_id=user_email, max_budget=None, user_email=user_email + # ) + # ), + with patch.object( + litellm.proxy.auth.user_api_key_auth, + "get_user_object", + side_effect=mock_user_object, + ) as mock_client: + if should_work: + # Expect the call to succeed + result = await user_api_key_auth(request=request, api_key=bearer_token) + assert result is not None # Adjust this based on your actual response check + else: + # Expect the call to fail + with pytest.raises( + Exception + ): # Replace with the actual exception raised on failure + resp = await user_api_key_auth(request=request, api_key=bearer_token) + print(resp) diff --git a/litellm/tests/test_lunary.py b/litellm/tests/test_lunary.py index cd068d990..d181d24c7 100644 --- a/litellm/tests/test_lunary.py +++ b/litellm/tests/test_lunary.py @@ -1,16 +1,17 @@ -import sys -import os import io +import os +import sys sys.path.insert(0, os.path.abspath("../..")) -from litellm import completion import litellm +from litellm import completion litellm.failure_callback = ["lunary"] litellm.success_callback = ["lunary"] litellm.set_verbose = True + def test_lunary_logging(): try: response = completion( @@ -24,6 +25,7 @@ def test_lunary_logging(): except Exception as e: print(e) + test_lunary_logging() @@ -37,8 +39,6 @@ def test_lunary_template(): except Exception as e: print(e) -test_lunary_template() - def test_lunary_logging_with_metadata(): try: @@ -50,19 +50,23 @@ def test_lunary_logging_with_metadata(): metadata={ "run_name": "litellmRUN", "project_name": 
"litellm-completion", - "tags": ["tag1", "tag2"] + "tags": ["tag1", "tag2"], }, ) print(response) except Exception as e: print(e) -test_lunary_logging_with_metadata() def test_lunary_with_tools(): import litellm - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + messages = [ + { + "role": "user", + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + } + ] tools = [ { "type": "function", @@ -90,13 +94,11 @@ def test_lunary_with_tools(): tools=tools, tool_choice="auto", # auto is default, but we'll be explicit ) - + response_message = response.choices[0].message print("\nLLM Response:\n", response.choices[0].message) -test_lunary_with_tools() - def test_lunary_logging_with_streaming_and_metadata(): try: response = completion( @@ -114,5 +116,3 @@ def test_lunary_logging_with_streaming_and_metadata(): continue except Exception as e: print(e) - -test_lunary_logging_with_streaming_and_metadata() diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py index 63361b09a..b5aac09d1 100644 --- a/litellm/tests/test_proxy_utils.py +++ b/litellm/tests/test_proxy_utils.py @@ -11,8 +11,11 @@ import litellm sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request +from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.litellm_pre_call_utils import ( + _get_dynamic_logging_metadata, + add_litellm_data_to_request, +) from litellm.types.utils import SupportedCacheControls @@ -204,3 +207,87 @@ async def test_add_key_or_team_level_spend_logs_metadata_to_request( # assert ( # new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"] # ) + + +@pytest.mark.parametrize( + "callback_vars", + [ + { + "langfuse_host": "https://us.cloud.langfuse.com", + "langfuse_public_key": "pk-lf-9636b7a6-c066", + "langfuse_secret_key": "sk-lf-7cc8b620", + }, + { + "langfuse_host": "os.environ/LANGFUSE_HOST_TEMP", + "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY_TEMP", + "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY_TEMP", + }, + ], +) +def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars): + os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066" + os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620" + os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com" + user_api_key_dict = UserAPIKeyAuth( + token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432", + key_name="sk-...63Fg", + key_alias=None, + spend=0.000111, + max_budget=None, + expires=None, + models=[], + aliases={}, + config={}, + user_id=None, + team_id="ishaan-special-team_e02dd54f-f790-4755-9f93-73734f415898", + max_parallel_requests=None, + metadata={ + "logging": [ + { + "callback_name": "langfuse", + "callback_type": "success", + "callback_vars": callback_vars, + } + ] + }, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + allowed_cache_controls=[], + permissions={}, + model_spend={}, + model_max_budget={}, + soft_budget_cooldown=False, + litellm_budget_table=None, + org_id=None, + team_spend=0.000132, + team_alias=None, + team_tpm_limit=None, + team_rpm_limit=None, + team_max_budget=None, + team_models=[], + team_blocked=False, + soft_budget=None, + team_model_aliases=None, + team_member_spend=None, + team_member=None, + team_metadata={}, + 
end_user_id=None, + end_user_tpm_limit=None, + end_user_rpm_limit=None, + end_user_max_budget=None, + last_refreshed_at=1726101560.967527, + api_key="7c305cc48fe72272700dc0d67dc691c2d1f2807490ef5eb2ee1d3a3ca86e12b1", + user_role=LitellmUserRoles.INTERNAL_USER, + allowed_model_region=None, + parent_otel_span=None, + rpm_limit_per_model=None, + tpm_limit_per_model=None, + ) + callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict) + + assert callbacks is not None + + for var in callbacks.callback_vars.values(): + assert "os.environ" not in var diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index fd89130fe..05d9f9f76 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -2121,7 +2121,7 @@ def test_router_cooldown_api_connection_error(): except litellm.APIConnectionError as e: assert ( Router()._is_cooldown_required( - exception_status=e.code, exception_str=str(e) + model_id="", exception_status=e.code, exception_str=str(e) ) is False ) @@ -2272,7 +2272,13 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): "litellm_params": { "model": "openai/text-embedding-ada-002", }, - } + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": { + "model": "openai/text-embedding-ada-002", + }, + }, ] ) diff --git a/litellm/tests/test_router_cooldowns.py b/litellm/tests/test_router_cooldowns.py index 3eef6e542..ac92dfbf0 100644 --- a/litellm/tests/test_router_cooldowns.py +++ b/litellm/tests/test_router_cooldowns.py @@ -21,6 +21,7 @@ import openai import litellm from litellm import Router from litellm.integrations.custom_logger import CustomLogger +from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict @pytest.mark.asyncio @@ -112,3 +113,40 @@ async def test_dynamic_cooldowns(): assert "cooldown_time" in tmp_mock.call_args[0][0]["litellm_params"] assert tmp_mock.call_args[0][0]["litellm_params"]["cooldown_time"] == 0 + + +@pytest.mark.parametrize("num_deployments", [1, 2]) +def test_single_deployment_no_cooldowns(num_deployments): + """ + Do not cooldown on single deployment. + + Cooldown on multiple deployments. 
+ """ + model_list = [] + for i in range(num_deployments): + model = DeploymentTypedDict( + model_name="gpt-3.5-turbo", + litellm_params=LiteLLMParamsTypedDict( + model="gpt-3.5-turbo", + ), + ) + model_list.append(model) + + router = Router(model_list=model_list, allowed_fails=0, num_retries=0) + + with patch.object( + router.cooldown_cache, "add_deployment_to_cooldown", new=MagicMock() + ) as mock_client: + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + mock_response="litellm.RateLimitError", + ) + except litellm.RateLimitError: + pass + + if num_deployments == 1: + mock_client.assert_not_called() + else: + mock_client.assert_called_once() diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index f0503cd3f..f4574212d 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -89,6 +89,17 @@ async def test_router_retries_errors(sync_mode, error_type): "tpm": 240000, "rpm": 1800, }, + { + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": _api_key, + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, ] router = Router(model_list=model_list, allowed_fails=3) diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py index 1f3de0910..259bd0ee0 100644 --- a/litellm/tests/test_tpm_rpm_routing_v2.py +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -17,6 +17,8 @@ import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from unittest.mock import AsyncMock, MagicMock, patch + import pytest import litellm @@ -459,3 +461,138 @@ async def test_router_completion_streaming(): - Unit test for sync 'pre_call_checks' - Unit test for async 'async_pre_call_checks' """ + + +@pytest.mark.asyncio +async def test_router_caching_ttl(): + """ + Confirm caching ttl's work as expected. 
+ + Relevant issue: https://github.com/BerriAI/litellm/issues/5609 + """ + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + "mock_response": "Hello world", + }, + "model_info": {"id": 1}, + } + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT"), + ) + + assert router.cache.redis_cache is not None + + increment_cache_kwargs = {} + with patch.object( + router.cache.redis_cache, + "async_increment", + new=AsyncMock(), + ) as mock_client: + await router.acompletion(model=model, messages=messages) + + mock_client.assert_called_once() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") + print(f"mock_client.call_args.args: {mock_client.call_args.args}") + + increment_cache_kwargs = { + "key": mock_client.call_args.args[0], + "value": mock_client.call_args.args[1], + "ttl": mock_client.call_args.kwargs["ttl"], + } + + assert mock_client.call_args.kwargs["ttl"] == 60 + + ## call redis async increment and check if ttl correctly set + await router.cache.redis_cache.async_increment(**increment_cache_kwargs) + + _redis_client = router.cache.redis_cache.init_async_client() + + async with _redis_client as redis_client: + current_ttl = await redis_client.ttl(increment_cache_kwargs["key"]) + + assert current_ttl >= 0 + + print(f"current_ttl: {current_ttl}") + + +def test_router_caching_ttl_sync(): + """ + Confirm caching ttl's work as expected. 
+ + Relevant issue: https://github.com/BerriAI/litellm/issues/5609 + """ + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + "mock_response": "Hello world", + }, + "model_info": {"id": 1}, + } + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT"), + ) + + assert router.cache.redis_cache is not None + + increment_cache_kwargs = {} + with patch.object( + router.cache.redis_cache, + "increment_cache", + new=MagicMock(), + ) as mock_client: + router.completion(model=model, messages=messages) + + print(mock_client.call_args_list) + mock_client.assert_called() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") + print(f"mock_client.call_args.args: {mock_client.call_args.args}") + + increment_cache_kwargs = { + "key": mock_client.call_args.args[0], + "value": mock_client.call_args.args[1], + "ttl": mock_client.call_args.kwargs["ttl"], + } + + assert mock_client.call_args.kwargs["ttl"] == 60 + + ## call redis async increment and check if ttl correctly set + router.cache.redis_cache.increment_cache(**increment_cache_kwargs) + + _redis_client = router.cache.redis_cache.redis_client + + current_ttl = _redis_client.ttl(increment_cache_kwargs["key"]) + + assert current_ttl >= 0 + + print(f"current_ttl: {current_ttl}") diff --git a/litellm/tests/test_traceloop.py b/litellm/tests/test_traceloop.py index bcc120323..74d58228e 100644 --- a/litellm/tests/test_traceloop.py +++ b/litellm/tests/test_traceloop.py @@ -4,7 +4,6 @@ import time import pytest from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from traceloop.sdk import Traceloop import litellm @@ -13,6 +12,8 @@ sys.path.insert(0, os.path.abspath("../..")) @pytest.fixture() def exporter(): + from traceloop.sdk import Traceloop + exporter = InMemorySpanExporter() Traceloop.init( app_name="test_litellm",
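The router changes earlier in this diff skip cooldowns when a model group contains a single deployment (`get_model_group` plus the base case in `_is_cooldown_required`), since there is nothing to fail over to. A self-contained sketch of that decision, using simplified dict entries rather than the real `Router` internals:

```python
from typing import List, Optional

# Simplified stand-ins for the router's model list entries (illustrative ids).
MODEL_LIST = [
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deploy-1"}},
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deploy-2"}},
    {"model_name": "text-embedding-ada-002", "model_info": {"id": "deploy-3"}},
]


def get_model_group(model_id: str, model_list: List[dict]) -> Optional[List[dict]]:
    """Return every deployment sharing a model_name with the given deployment id."""
    match = next((m for m in model_list if m["model_info"]["id"] == model_id), None)
    if match is None:
        return None
    return [m for m in model_list if m["model_name"] == match["model_name"]]


def is_cooldown_required(model_id: str, model_list: List[dict]) -> bool:
    # Base case: with a single deployment, cooling it down would only turn
    # retryable errors into hard failures, so never cool it down.
    group = get_model_group(model_id, model_list)
    if group is not None and len(group) == 1:
        return False
    return True  # otherwise defer to the usual status-code / exception checks


assert is_cooldown_required("deploy-3", MODEL_LIST) is False  # single deployment
assert is_cooldown_required("deploy-1", MODEL_LIST) is True   # has a peer to fall back to
```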