diff --git a/litellm/litellm_core_utils/get_model_cost_map.py b/litellm/litellm_core_utils/get_model_cost_map.py
index b8bdaee19c..b6a3a243c4 100644
--- a/litellm/litellm_core_utils/get_model_cost_map.py
+++ b/litellm/litellm_core_utils/get_model_cost_map.py
@@ -13,7 +13,7 @@ import os
import httpx
-def get_model_cost_map(url: str):
+def get_model_cost_map(url: str) -> dict:
if (
os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
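The hunk above only adds a return annotation to get_model_cost_map; behavior is unchanged. A minimal usage sketch, assuming the LITELLM_LOCAL_MODEL_COST_MAP short-circuit shown in the context lines (the empty url is illustrative, since no fetch happens on the local path):

    import os
    from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map

    # When LITELLM_LOCAL_MODEL_COST_MAP is set, the bundled local map is
    # returned and the url is never fetched.
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    cost_map: dict = get_model_cost_map(url="")
    assert isinstance(cost_map, dict)  # matches the new -> dict annotation
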
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 1b1ad5c2cc..0000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 2ea6dcf038..1e7184db93 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -644,9 +644,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {}
permissions: Optional[dict] = {}
- model_max_budget: Optional[dict] = (
- {}
- ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
+ model_max_budget: Optional[
+ dict
+ ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None
@@ -902,12 +902,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
- allowed_model_region: Optional[AllowedModelRegion] = (
- None # require all user requests to use models in this specific region
- )
- default_model: Optional[str] = (
- None # if no equivalent model in allowed region - default all requests to this model
- )
+ allowed_model_region: Optional[
+ AllowedModelRegion
+ ] = None # require all user requests to use models in this specific region
+ default_model: Optional[
+ str
+ ] = None # if no equivalent model in allowed region - default all requests to this model
@model_validator(mode="before")
@classmethod
@@ -929,12 +929,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
- allowed_model_region: Optional[AllowedModelRegion] = (
- None # require all user requests to use models in this specific region
- )
- default_model: Optional[str] = (
- None # if no equivalent model in allowed region - default all requests to this model
- )
+ allowed_model_region: Optional[
+ AllowedModelRegion
+ ] = None # require all user requests to use models in this specific region
+ default_model: Optional[
+ str
+ ] = None # if no equivalent model in allowed region - default all requests to this model
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -1070,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str
- callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
- "success_and_failure"
- )
+ callback_type: Optional[
+ Literal["success", "failure", "success_and_failure"]
+ ] = "success_and_failure"
callback_vars: Dict[str, str]
@model_validator(mode="before")
@@ -1329,9 +1329,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool]
field_default_value: Any
premium_field: bool = False
- nested_fields: Optional[List[FieldDetail]] = (
- None # For nested dictionary or Pydantic fields
- )
+ nested_fields: Optional[
+ List[FieldDetail]
+ ] = None # For nested dictionary or Pydantic fields
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@@ -1558,6 +1558,7 @@ class UserAPIKeyAuth(
user_tpm_limit: Optional[int] = None
user_rpm_limit: Optional[int] = None
user_email: Optional[str] = None
+ request_route: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -1597,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
- user: Optional[Any] = (
- None # You might want to replace 'Any' with a more specific type if available
- )
+ user: Optional[
+ Any
+ ] = None # You might want to replace 'Any' with a more specific type if available
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
model_config = ConfigDict(protected_namespaces=())
@@ -2345,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
- max_budget_in_organization: Optional[float] = (
- None # Users max budget within the organization
- )
+ max_budget_in_organization: Optional[
+ float
+ ] = None # Users max budget within the organization
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -2536,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs.
"""
- providers: Dict[str, ProviderBudgetResponseObject] = (
- {}
- ) # Dictionary mapping provider names to their budget configurations
+ providers: Dict[
+ str, ProviderBudgetResponseObject
+ ] = {} # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
@@ -2666,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
- object_id_jwt_field: Optional[str] = (
- None # can be either user / team, inferred from the role mapping
- )
+ object_id_jwt_field: Optional[
+ str
+ ] = None # can be either user / team, inferred from the role mapping
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False
diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py
index 7c97655141..268e3bb1b2 100644
--- a/litellm/proxy/auth/auth_exception_handler.py
+++ b/litellm/proxy/auth/auth_exception_handler.py
@@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler:
key_name="failed-to-connect-to-db",
token="failed-to-connect-to-db",
user_id=litellm_proxy_admin_name,
+ request_route=route,
)
else:
# raise the exception to the caller
@@ -87,6 +88,7 @@ class UserAPIKeyAuthExceptionHandler:
user_api_key_dict = UserAPIKeyAuth(
parent_otel_span=parent_otel_span,
api_key=api_key,
+ request_route=route,
)
asyncio.create_task(
proxy_logging_obj.post_call_failure_hook(
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 1b140bdaab..97e9fb8c73 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -1023,6 +1023,7 @@ async def user_api_key_auth(
"""
request_data = await _read_request_body(request=request)
+ route: str = get_request_route(request=request)
user_api_key_auth_obj = await _user_api_key_auth_builder(
request=request,
@@ -1038,6 +1039,8 @@ async def user_api_key_auth(
if end_user_id is not None:
user_api_key_auth_obj.end_user_id = end_user_id
+ user_api_key_auth_obj.request_route = route
+
return user_api_key_auth_obj
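Together with the auth_exception_handler changes above, these hunks attach the resolved route to every UserAPIKeyAuth object the proxy builds, so downstream hooks can inspect it. A minimal sketch of what consumers now see (the key and route values are illustrative):

    from litellm.proxy._types import UserAPIKeyAuth

    auth = UserAPIKeyAuth(api_key="sk-test", request_route="/chat/completions")
    assert auth.request_route == "/chat/completions"
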
diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py
index 2fd587f46c..cf0e0a07ed 100644
--- a/litellm/proxy/hooks/proxy_track_cost_callback.py
+++ b/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -13,6 +13,7 @@ from litellm.litellm_core_utils.core_helpers import (
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
StandardLoggingPayload,
@@ -33,8 +34,13 @@ class _ProxyDBLogger(CustomLogger):
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
):
+ request_route = user_api_key_dict.request_route
if _ProxyDBLogger._should_track_errors_in_db() is False:
return
+ elif request_route is not None and not RouteChecks.is_llm_api_route(
+ route=request_route
+ ):
+ return
from litellm.proxy.proxy_server import proxy_logging_obj
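The new guard makes the DB failure tracker a no-op for non-LLM routes. Illustrative behavior, using the call signature from the hunk above (whether a given route counts as an LLM API route is an assumption about RouteChecks, not something this diff shows):

    from litellm.proxy.auth.route_checks import RouteChecks

    # Presumably True for LLM endpoints: the failure is written to the DB.
    RouteChecks.is_llm_api_route(route="/chat/completions")
    # Presumably False for custom routes: async_post_call_failure_hook returns early.
    RouteChecks.is_llm_api_route(route="/custom/route")
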
diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 11f6e5e603..d37163d2ef 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)
- response["soft_budget"] = (
- data.soft_budget
- ) # include the user-input soft budget in the response
+ response[
+ "soft_budget"
+ ] = data.soft_budget # include the user-input soft budget in the response
response = GenerateKeyResponse(**response)
@@ -1467,10 +1467,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
- _keys_being_deleted: List[LiteLLM_VerificationToken] = (
- await prisma_client.db.litellm_verificationtoken.find_many(
- where={"token": {"in": tokens}}
- )
+ _keys_being_deleted: List[
+ LiteLLM_VerificationToken
+ ] = await prisma_client.db.litellm_verificationtoken.find_many(
+ where={"token": {"in": tokens}}
)
# Assuming 'db' is your Prisma Client instance
@@ -1572,9 +1572,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config
try:
- models: Optional[List] = (
- await prisma_client.db.litellm_proxymodeltable.find_many()
- )
+ models: Optional[
+ List
+ ] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
@@ -1861,11 +1861,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
- complete_user_info_db_obj: Optional[BaseModel] = (
- await prisma_client.db.litellm_usertable.find_unique(
- where={"user_id": user_api_key_dict.user_id},
- include={"organization_memberships": True},
- )
+ complete_user_info_db_obj: Optional[
+ BaseModel
+ ] = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": user_api_key_dict.user_id},
+ include={"organization_memberships": True},
)
if complete_user_info_db_obj is None:
@@ -1926,10 +1926,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
- teams: Optional[List[BaseModel]] = (
- await prisma_client.db.litellm_teamtable.find_many(
- where={"team_id": {"in": complete_user_info.teams}}
- )
+ teams: Optional[
+ List[BaseModel]
+ ] = await prisma_client.db.litellm_teamtable.find_many(
+ where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []
@@ -2080,7 +2080,6 @@ async def _list_key_helper(
"total_pages": int,
}
"""
-
# Prepare filter conditions
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
where.update(_get_condition_to_filter_out_ui_session_tokens())
@@ -2110,7 +2109,7 @@ async def _list_key_helper(
# Combine conditions with OR if we have multiple conditions
if len(or_conditions) > 1:
- where["OR"] = or_conditions
+ where = {"AND": [where, {"OR": or_conditions}]}
elif len(or_conditions) == 1:
where.update(or_conditions[0])
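The one-line change in the last hunk rewrites how the OR group is combined with the base filters. Assigning where["OR"] mutated the base dict and would silently overwrite any existing top-level "OR" key; the explicit AND wrapper keeps the composition unambiguous and preserves the base filters, such as the UI-session-token exclusion the new test below asserts on. A sketch of the resulting Prisma where-clause shape (the filter values are illustrative):

    base = {"team_id": {"not": "litellm-dashboard"}}  # excludes UI session tokens
    or_conditions = [{"user_id": "u-123"}, {"team_id": {"in": ["t-456"]}}]

    # Old: base["OR"] = or_conditions  (mutates base; clobbers an existing "OR")
    where = {"AND": [base, {"OR": or_conditions}]}
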
diff --git a/litellm/router.py b/litellm/router.py
index 4a466f4119..c934b1e9a8 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -339,9 +339,9 @@ class Router:
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
- cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
- "local" # default to an in-memory cache
- )
+ cache_type: Literal[
+ "local", "redis", "redis-semantic", "s3", "disk"
+ ] = "local" # default to an in-memory cache
redis_cache = None
cache_config: Dict[str, Any] = {}
@@ -562,9 +562,9 @@ class Router:
)
)
- self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
- model_group_retry_policy
- )
+ self.model_group_retry_policy: Optional[
+ Dict[str, RetryPolicy]
+ ] = model_group_retry_policy
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
if allowed_fails_policy is not None:
@@ -1105,9 +1105,9 @@ class Router:
"""
Adds default litellm params to kwargs, if set.
"""
- self.default_litellm_params[metadata_variable_name] = (
- self.default_litellm_params.pop("metadata", {})
- )
+ self.default_litellm_params[
+ metadata_variable_name
+ ] = self.default_litellm_params.pop("metadata", {})
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
@@ -3243,11 +3243,11 @@ class Router:
if isinstance(e, litellm.ContextWindowExceededError):
if context_window_fallbacks is not None:
- fallback_model_group: Optional[List[str]] = (
- self._get_fallback_model_group_from_fallbacks(
- fallbacks=context_window_fallbacks,
- model_group=model_group,
- )
+ fallback_model_group: Optional[
+ List[str]
+ ] = self._get_fallback_model_group_from_fallbacks(
+ fallbacks=context_window_fallbacks,
+ model_group=model_group,
)
if fallback_model_group is None:
raise original_exception
@@ -3279,11 +3279,11 @@ class Router:
e.message += "\n{}".format(error_message)
elif isinstance(e, litellm.ContentPolicyViolationError):
if content_policy_fallbacks is not None:
- fallback_model_group: Optional[List[str]] = (
- self._get_fallback_model_group_from_fallbacks(
- fallbacks=content_policy_fallbacks,
- model_group=model_group,
- )
+ fallback_model_group: Optional[
+ List[str]
+ ] = self._get_fallback_model_group_from_fallbacks(
+ fallbacks=content_policy_fallbacks,
+ model_group=model_group,
)
if fallback_model_group is None:
raise original_exception
@@ -4853,10 +4853,11 @@ class Router:
from litellm.utils import _update_dictionary
model_info: Optional[ModelInfo] = None
+ custom_model_info: Optional[dict] = None
litellm_model_name_model_info: Optional[ModelInfo] = None
try:
- model_info = litellm.get_model_info(model=model_id)
+ custom_model_info = litellm.model_cost.get(model_id)
except Exception:
pass
@@ -4865,14 +4866,16 @@ class Router:
except Exception:
pass
- if model_info is not None and litellm_model_name_model_info is not None:
+ if custom_model_info is not None and litellm_model_name_model_info is not None:
model_info = cast(
ModelInfo,
_update_dictionary(
cast(dict, litellm_model_name_model_info).copy(),
- cast(dict, model_info),
+ custom_model_info,
),
)
+ elif litellm_model_name_model_info is not None:
+ model_info = litellm_model_name_model_info
return model_info
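Net effect of the last router hunk: a per-deployment entry in litellm.model_cost now overrides the base info of the underlying model, and when no custom entry exists the base info is returned as-is (previously, a failed get_model_info lookup on the deployment id left model_info as None even when the base lookup succeeded). A simplified sketch of the merge, using the helper the surrounding code already imports (the deployment id is illustrative):

    import litellm
    from litellm.utils import _update_dictionary

    base_info = dict(litellm.get_model_info(model="gpt-4"))   # info for the underlying model
    custom_info = litellm.model_cost.get("my-deployment-id")  # user-supplied override, may be None

    model_info = (
        _update_dictionary(base_info.copy(), custom_info)  # custom fields win on conflict
        if custom_info is not None
        else base_info
    )
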
diff --git a/tests/litellm/proxy/auth/test_auth_exception_handler.py b/tests/litellm/proxy/auth/test_auth_exception_handler.py
index 224bf24b57..3e780c6ee9 100644
--- a/tests/litellm/proxy/auth/test_auth_exception_handler.py
+++ b/tests/litellm/proxy/auth/test_auth_exception_handler.py
@@ -2,7 +2,7 @@ import asyncio
import json
import os
import sys
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
@@ -110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded():
)
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
+
+
+@pytest.mark.asyncio
+async def test_route_passed_to_post_call_failure_hook():
+ """
+ Verify that the route is passed through to post_call_failure_hook; _ProxyDBLogger's async_post_call_failure_hook uses it to check whether the route is an LLM API route.
+ """
+ handler = UserAPIKeyAuthExceptionHandler()
+
+ # Mock request and other dependencies
+ mock_request = MagicMock()
+ mock_request_data = {}
+ test_route = "/custom/route"
+ mock_span = None
+ mock_api_key = "test-key"
+
+ # Mock proxy_logging_obj.post_call_failure_hook
+ with patch(
+ "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
+ new_callable=AsyncMock,
+ ) as mock_post_call_failure_hook:
+ # Test with DB connection error
+ with patch(
+ "litellm.proxy.proxy_server.general_settings",
+ {"allow_requests_on_db_unavailable": False},
+ ):
+ try:
+ await handler._handle_authentication_error(
+ PrismaError(),
+ mock_request,
+ mock_request_data,
+ test_route,
+ mock_span,
+ mock_api_key,
+ )
+ except Exception:
+ pass
+ await asyncio.sleep(1)  # give the create_task'd failure hook a chance to run
+ # Verify post_call_failure_hook was called with the correct route
+ mock_post_call_failure_hook.assert_called_once()
+ call_args = mock_post_call_failure_hook.call_args[1]
+ assert call_args["user_api_key_dict"].request_route == test_route
diff --git a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py
index 8850436329..cb6d90103f 100644
--- a/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py
+++ b/tests/litellm/proxy/hooks/test_proxy_track_cost_callback.py
@@ -81,3 +81,48 @@ async def test_async_post_call_failure_hook():
assert metadata["status"] == "failure"
assert "error_information" in metadata
assert metadata["original_key"] == "original_value"
+
+
+@pytest.mark.asyncio
+async def test_async_post_call_failure_hook_non_llm_route():
+ # Setup
+ logger = _ProxyDBLogger()
+
+ # Mock user_api_key_dict with a non-LLM route
+ user_api_key_dict = UserAPIKeyAuth(
+ api_key="test_api_key",
+ key_alias="test_alias",
+ user_email="test@example.com",
+ user_id="test_user_id",
+ team_id="test_team_id",
+ org_id="test_org_id",
+ team_alias="test_team_alias",
+ end_user_id="test_end_user_id",
+ request_route="/custom/route", # Non-LLM route
+ )
+
+ # Mock request data
+ request_data = {
+ "model": "gpt-4",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "metadata": {"original_key": "original_value"},
+ "proxy_server_request": {"request_id": "test_request_id"},
+ }
+
+ # Mock exception
+ original_exception = Exception("Test exception")
+
+ # Mock update_database function
+ with patch(
+ "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database",
+ new_callable=AsyncMock,
+ ) as mock_update_database:
+ # Call the method
+ await logger.async_post_call_failure_hook(
+ request_data=request_data,
+ original_exception=original_exception,
+ user_api_key_dict=user_api_key_dict,
+ )
+
+ # Assert that update_database was NOT called for non-LLM routes
+ mock_update_database.assert_not_called()
diff --git a/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py
new file mode 100644
index 0000000000..51bbbb49c4
--- /dev/null
+++ b/tests/litellm/proxy/management_endpoints/test_key_management_endpoints.py
@@ -0,0 +1,48 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+ 0, os.path.abspath("../../../..")
+) # Adds the parent directory to the system path
+
+from unittest.mock import AsyncMock, MagicMock
+
+from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper
+from litellm.proxy.proxy_server import app
+
+client = TestClient(app)
+
+
+@pytest.mark.asyncio
+async def test_list_keys():
+ mock_prisma_client = AsyncMock()
+ mock_find_many = AsyncMock(return_value=[])
+ mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many
+ args = {
+ "prisma_client": mock_prisma_client,
+ "page": 1,
+ "size": 50,
+ "user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00",
+ "team_id": None,
+ "organization_id": None,
+ "key_alias": None,
+ "exclude_team_id": None,
+ "return_full_object": True,
+ "admin_team_ids": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"],
+ }
+ try:
+ await _list_key_helper(**args)
+ except Exception as e:
+ print(f"error: {e}")
+
+ mock_find_many.assert_called_once()
+
+ where_condition = mock_find_many.call_args.kwargs["where"]
+ print(f"where_condition: {where_condition}")
+ assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
+ where_condition
+ )
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 68a79f94a6..13eaeb09ab 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials():
deployment = router.get_deployment(model_id=original_model_id)
assert deployment is not None
assert deployment.litellm_params.api_key == original_api_key
+
+
+def test_router_get_model_group_info():
+ router = Router(
+ model_list=[
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {"model": "gpt-3.5-turbo"},
+ },
+ {
+ "model_name": "gpt-4",
+ "litellm_params": {"model": "gpt-4"},
+ },
+ ],
+ )
+
+ model_group_info = router.get_model_group_info(model_group="gpt-4")
+ assert model_group_info is not None
+ assert model_group_info.model_group == "gpt-4"
+ assert model_group_info.input_cost_per_token > 0
+ assert model_group_info.output_cost_per_token > 0
\ No newline at end of file
diff --git a/ui/litellm-dashboard/src/components/model_info_view.tsx b/ui/litellm-dashboard/src/components/model_info_view.tsx
index 6c626300a3..22c4fefc99 100644
--- a/ui/litellm-dashboard/src/components/model_info_view.tsx
+++ b/ui/litellm-dashboard/src/components/model_info_view.tsx
@@ -448,7 +448,7 @@ export default function ModelInfoView({
- RPM VVV(Requests per Minute)
+ RPM (Requests per Minute)
{isEditing ? (