From 00e49380df014d194577ff4150848645d22e4cf3 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 12 Apr 2025 19:30:48 -0700 Subject: [PATCH] Litellm UI qa 04 12 2025 p1 (#9955) * fix(model_info_view.tsx): cleanup text * fix(key_management_endpoints.py): fix filtering litellm-dashboard keys for internal users * fix(proxy_track_cost_callback.py): prevent flooding spend logs with admin endpoint errors * test: add unit testing for logic * test(test_auth_exception_handler.py): add more unit testing * fix(router.py): correctly handle retrieving model info on get_model_group_info fixes issue where model hub was showing None prices * fix: fix linting errors --- .../litellm_core_utils/get_model_cost_map.py | 2 +- .../proxy/_experimental/out/onboarding.html | 1 - litellm/proxy/_types.py | 67 ++++++++++--------- litellm/proxy/auth/auth_exception_handler.py | 2 + litellm/proxy/auth/user_api_key_auth.py | 3 + .../proxy/hooks/proxy_track_cost_callback.py | 6 ++ .../key_management_endpoints.py | 41 ++++++------ litellm/router.py | 47 +++++++------ .../proxy/auth/test_auth_exception_handler.py | 44 +++++++++++- .../hooks/test_proxy_track_cost_callback.py | 45 +++++++++++++ .../test_key_management_endpoints.py | 48 +++++++++++++ tests/local_testing/test_router.py | 21 ++++++ .../src/components/model_info_view.tsx | 2 +- 13 files changed, 249 insertions(+), 80 deletions(-) delete mode 100644 litellm/proxy/_experimental/out/onboarding.html create mode 100644 tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py diff --git a/litellm/litellm_core_utils/get_model_cost_map.py b/litellm/litellm_core_utils/get_model_cost_map.py index b8bdaee19c..b6a3a243c4 100644 --- a/litellm/litellm_core_utils/get_model_cost_map.py +++ b/litellm/litellm_core_utils/get_model_cost_map.py @@ -13,7 +13,7 @@ import os import httpx -def get_model_cost_map(url: str): +def get_model_cost_map(url: str) -> dict: if ( os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True" diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 1b1ad5c2cc..0000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2ea6dcf038..1e7184db93 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -644,9 +644,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -902,12 +902,12 @@ class NewCustomerRequest(BudgetNewRequest): alias: Optional[str] = None # human-friendly alias blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require 
all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -929,12 +929,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1070,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1329,9 +1329,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class ConfigGeneralSettings(LiteLLMPydanticObjectBase): @@ -1558,6 +1558,7 @@ class UserAPIKeyAuth( user_tpm_limit: Optional[int] = None user_rpm_limit: Optional[int] = None user_email: Optional[str] = None + request_route: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -1597,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -2345,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -2536,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. 
""" - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -2666,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index 7c97655141..268e3bb1b2 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler: key_name="failed-to-connect-to-db", token="failed-to-connect-to-db", user_id=litellm_proxy_admin_name, + request_route=route, ) else: # raise the exception to the caller @@ -87,6 +88,7 @@ class UserAPIKeyAuthExceptionHandler: user_api_key_dict = UserAPIKeyAuth( parent_otel_span=parent_otel_span, api_key=api_key, + request_route=route, ) asyncio.create_task( proxy_logging_obj.post_call_failure_hook( diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 1b140bdaab..97e9fb8c73 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -1023,6 +1023,7 @@ async def user_api_key_auth( """ request_data = await _read_request_body(request=request) + route: str = get_request_route(request=request) user_api_key_auth_obj = await _user_api_key_auth_builder( request=request, @@ -1038,6 +1039,8 @@ async def user_api_key_auth( if end_user_id is not None: user_api_key_auth_obj.end_user_id = end_user_id + user_api_key_auth_obj.request_route = route + return user_api_key_auth_obj diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 2fd587f46c..cf0e0a07ed 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -13,6 +13,7 @@ from litellm.litellm_core_utils.core_helpers import ( from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_checks import log_db_metrics +from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.utils import ProxyUpdateSpend from litellm.types.utils import ( StandardLoggingPayload, @@ -33,8 +34,13 @@ class _ProxyDBLogger(CustomLogger): original_exception: Exception, user_api_key_dict: UserAPIKeyAuth, ): + request_route = user_api_key_dict.request_route if _ProxyDBLogger._should_track_errors_in_db() is False: return + elif request_route is not None and not RouteChecks.is_llm_api_route( + route=request_route + ): + return from litellm.proxy.proxy_server import proxy_logging_obj diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 11f6e5e603..d37163d2ef 100644 --- 
a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -1467,10 +1467,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) # Assuming 'db' is your Prisma Client instance @@ -1572,9 +1572,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -1861,11 +1861,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -1926,10 +1926,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] @@ -2080,7 +2080,6 @@ async def _list_key_helper( "total_pages": int, } """ - # Prepare filter conditions where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {} where.update(_get_condition_to_filter_out_ui_session_tokens()) @@ -2110,7 +2109,7 @@ async def _list_key_helper( # Combine conditions with OR if we have multiple conditions if len(or_conditions) > 1: - where["OR"] = or_conditions + where = {"AND": [where, {"OR": or_conditions}]} elif len(or_conditions) == 1: where.update(or_conditions[0]) diff --git a/litellm/router.py b/litellm/router.py index 4a466f4119..c934b1e9a8 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -339,9 +339,9 @@ class Router: ) # names of models under litellm_params. ex. 
azure/chatgpt-v-2 self.deployment_latency_map = {} ### CACHING ### - cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = ( - "local" # default to an in-memory cache - ) + cache_type: Literal[ + "local", "redis", "redis-semantic", "s3", "disk" + ] = "local" # default to an in-memory cache redis_cache = None cache_config: Dict[str, Any] = {} @@ -562,9 +562,9 @@ class Router: ) ) - self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( - model_group_retry_policy - ) + self.model_group_retry_policy: Optional[ + Dict[str, RetryPolicy] + ] = model_group_retry_policy self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None if allowed_fails_policy is not None: @@ -1105,9 +1105,9 @@ class Router: """ Adds default litellm params to kwargs, if set. """ - self.default_litellm_params[metadata_variable_name] = ( - self.default_litellm_params.pop("metadata", {}) - ) + self.default_litellm_params[ + metadata_variable_name + ] = self.default_litellm_params.pop("metadata", {}) for k, v in self.default_litellm_params.items(): if ( k not in kwargs and v is not None @@ -3243,11 +3243,11 @@ class Router: if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=context_window_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -3279,11 +3279,11 @@ class Router: e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=content_policy_fallbacks, - model_group=model_group, - ) + fallback_model_group: Optional[ + List[str] + ] = self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, ) if fallback_model_group is None: raise original_exception @@ -4853,10 +4853,11 @@ class Router: from litellm.utils import _update_dictionary model_info: Optional[ModelInfo] = None + custom_model_info: Optional[dict] = None litellm_model_name_model_info: Optional[ModelInfo] = None try: - model_info = litellm.get_model_info(model=model_id) + custom_model_info = litellm.model_cost.get(model_id) except Exception: pass @@ -4865,14 +4866,16 @@ class Router: except Exception: pass - if model_info is not None and litellm_model_name_model_info is not None: + if custom_model_info is not None and litellm_model_name_model_info is not None: model_info = cast( ModelInfo, _update_dictionary( cast(dict, litellm_model_name_model_info).copy(), - cast(dict, model_info), + custom_model_info, ), ) + elif litellm_model_name_model_info is not None: + model_info = litellm_model_name_model_info return model_info diff --git a/tests/litellm/proxy/auth/test_auth_exception_handler.py b/tests/litellm/proxy/auth/test_auth_exception_handler.py index 224bf24b57..3e780c6ee9 100644 --- a/tests/litellm/proxy/auth/test_auth_exception_handler.py +++ b/tests/litellm/proxy/auth/test_auth_exception_handler.py @@ -2,7 +2,7 @@ import asyncio import json import os import sys -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException, Request, status @@ 
-110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded(): ) assert exc_info.value.type == ProxyErrorTypes.budget_exceeded + + +@pytest.mark.asyncio +async def test_route_passed_to_post_call_failure_hook(): + """ + Ensure the request route is passed through to post_call_failure_hook - proxy track_cost_callback's async_post_call_failure_hook uses it to check whether the route is an LLM route + """ + handler = UserAPIKeyAuthExceptionHandler() + + # Mock request and other dependencies + mock_request = MagicMock() + mock_request_data = {} + test_route = "/custom/route" + mock_span = None + mock_api_key = "test-key" + + # Mock proxy_logging_obj.post_call_failure_hook + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook", + new_callable=AsyncMock, + ) as mock_post_call_failure_hook: + # Test with DB connection error + with patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, + ): + try: + await handler._handle_authentication_error( + PrismaError(), + mock_request, + mock_request_data, + test_route, + mock_span, + mock_api_key, + ) + except Exception: + pass + await asyncio.sleep(1) + # Verify post_call_failure_hook was called with the correct route + mock_post_call_failure_hook.assert_called_once() + call_args = mock_post_call_failure_hook.call_args[1] + assert call_args["user_api_key_dict"].request_route == test_route diff --git a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py index 8850436329..cb6d90103f 100644 --- a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py +++ b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py @@ -81,3 +81,48 @@ async def test_async_post_call_failure_hook(): assert metadata["status"] == "failure" assert "error_information" in metadata assert metadata["original_key"] == "original_value" + + +@pytest.mark.asyncio +async def test_async_post_call_failure_hook_non_llm_route(): + # Setup + logger = _ProxyDBLogger() + + # Mock user_api_key_dict with a non-LLM route + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + key_alias="test_alias", + user_email="test@example.com", + user_id="test_user_id", + team_id="test_team_id", + org_id="test_org_id", + team_alias="test_team_alias", + end_user_id="test_end_user_id", + request_route="/custom/route", # Non-LLM route + ) + + # Mock request data + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": {"original_key": "original_value"}, + "proxy_server_request": {"request_id": "test_request_id"}, + } + + # Mock exception + original_exception = Exception("Test exception") + + # Mock update_database function + with patch( + "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database", + new_callable=AsyncMock, + ) as mock_update_database: + # Call the method + await logger.async_post_call_failure_hook( + request_data=request_data, + original_exception=original_exception, + user_api_key_dict=user_api_key_dict, + ) + + # Assert that update_database was NOT called for non-LLM routes + mock_update_database.assert_not_called() diff --git a/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py new file mode 100644 index 0000000000..51bbbb49c4 --- /dev/null +++ b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -0,0 +1,48 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient 
import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper +from litellm.proxy.proxy_server import app + +client = TestClient(app) + + +@pytest.mark.asyncio +async def test_list_keys(): + mock_prisma_client = AsyncMock() + mock_find_many = AsyncMock(return_value=[]) + mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many + args = { + "prisma_client": mock_prisma_client, + "page": 1, + "size": 50, + "user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00", + "team_id": None, + "organization_id": None, + "key_alias": None, + "exclude_team_id": None, + "return_full_object": True, + "admin_team_ids": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"], + } + try: + result = await _list_key_helper(**args) + except Exception as e: + print(f"error: {e}") + + mock_find_many.assert_called_once() + + where_condition = mock_find_many.call_args.kwargs["where"] + print(f"where_condition: {where_condition}") + assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps( + where_condition + ) diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 68a79f94a6..13eaeb09ab 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials(): deployment = router.get_deployment(model_id=original_model_id) assert deployment is not None assert deployment.litellm_params.api_key == original_api_key + + +def test_router_get_model_group_info(): + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo"}, + }, + { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4"}, + }, + ], + ) + + model_group_info = router.get_model_group_info(model_group="gpt-4") + assert model_group_info is not None + assert model_group_info.model_group == "gpt-4" + assert model_group_info.input_cost_per_token > 0 + assert model_group_info.output_cost_per_token > 0 \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/model_info_view.tsx b/ui/litellm-dashboard/src/components/model_info_view.tsx index 6c626300a3..22c4fefc99 100644 --- a/ui/litellm-dashboard/src/components/model_info_view.tsx +++ b/ui/litellm-dashboard/src/components/model_info_view.tsx @@ -448,7 +448,7 @@ export default function ModelInfoView({
- RPM VVV(Requests per Minute) + RPM (Requests per Minute) {isEditing ? (
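-- 
Reviewer note (not part of the patch): the spend-log gating added in litellm/proxy/hooks/proxy_track_cost_callback.py can be summarized with the minimal, self-contained sketch below. `is_llm_api_route` here is a simplified stand-in for LiteLLM's RouteChecks.is_llm_api_route (the real check covers far more routes), and `should_log_failure_to_spend_logs` is a hypothetical helper name used only for illustration.

    from typing import Optional

    # Illustrative subset; RouteChecks.is_llm_api_route matches many more paths.
    LLM_API_PREFIXES = ("/chat/completions", "/completions", "/embeddings", "/v1/")

    def is_llm_api_route(route: str) -> bool:
        # Stand-in: a route counts as an LLM route if it matches a known prefix.
        return route.startswith(LLM_API_PREFIXES)

    def should_log_failure_to_spend_logs(
        track_errors_in_db: bool, request_route: Optional[str]
    ) -> bool:
        # Mirrors the early returns in _ProxyDBLogger.async_post_call_failure_hook:
        # skip when error tracking is disabled, and skip when the failing route is
        # known and is not an LLM API route (e.g. admin/management endpoints), so
        # auth failures on those endpoints no longer flood the spend logs.
        if not track_errors_in_db:
            return False
        if request_route is not None and not is_llm_api_route(request_route):
            return False
        return True

    assert should_log_failure_to_spend_logs(True, "/chat/completions")
    assert not should_log_failure_to_spend_logs(True, "/custom/route")
    assert should_log_failure_to_spend_logs(True, None)  # unknown route: keep logging

Threading request_route through UserAPIKeyAuth (set in user_api_key_auth and in the auth exception handler, per the diffs above) is what makes this check possible on the failure path, where the parsed route was previously unavailable.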