Litellm UI qa 04 12 2025 p1 (#9955)

* fix(model_info_view.tsx): cleanup text

* fix(key_management_endpoints.py): fix filtering litellm-dashboard keys for internal users

* fix(proxy_track_cost_callback.py): prevent flooding spend logs with admin endpoint errors

* test: add unit testing for logic

* test(test_auth_exception_handler.py): add more unit testing

* fix(router.py): correctly handle retrieving model info on get_model_group_info

fixes issue where model hub was showing None prices

* fix: fix linting errors
Krish Dholakia 2025-04-12 19:30:48 -07:00 committed by GitHub
parent f8d52e2db9
commit 00e49380df
13 changed files with 249 additions and 80 deletions

View file

@@ -13,7 +13,7 @@ import os
 import httpx


-def get_model_cost_map(url: str):
+def get_model_cost_map(url: str) -> dict:
     if (
         os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
         or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
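
The only functional change above is the new -> dict return annotation. For context, a minimal sketch of how this helper behaves, assuming the usual two branches; the local-map loader name is hypothetical and the fetch details are paraphrased, not the exact implementation:

import os

import httpx


def get_model_cost_map(url: str) -> dict:
    # When LITELLM_LOCAL_MODEL_COST_MAP is set, skip the network call and use
    # the cost map bundled with the package (loader name is an assumption).
    if (
        os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
        or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
    ):
        return _load_bundled_cost_map()  # hypothetical helper reading the packaged JSON
    # Otherwise fetch the latest prices from the hosted JSON map.
    response = httpx.get(url, timeout=5)
    response.raise_for_status()
    return response.json()

The annotation lets callers (and the type checker) rely on a dict coming back from either branch.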

File diff suppressed because one or more lines are too long

View file

@@ -644,9 +644,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
     allowed_cache_controls: Optional[list] = []
     config: Optional[dict] = {}
     permissions: Optional[dict] = {}
-    model_max_budget: Optional[dict] = (
-        {}
-    )  # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
+    model_max_budget: Optional[
+        dict
+    ] = {}  # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}

     model_config = ConfigDict(protected_namespaces=())
     model_rpm_limit: Optional[dict] = None
@@ -902,12 +902,12 @@ class NewCustomerRequest(BudgetNewRequest):
     alias: Optional[str] = None  # human-friendly alias
     blocked: bool = False  # allow/disallow requests for this end-user
     budget_id: Optional[str] = None  # give either a budget_id or max_budget
-    allowed_model_region: Optional[AllowedModelRegion] = (
-        None  # require all user requests to use models in this specific region
-    )
-    default_model: Optional[str] = (
-        None  # if no equivalent model in allowed region - default all requests to this model
-    )
+    allowed_model_region: Optional[
+        AllowedModelRegion
+    ] = None  # require all user requests to use models in this specific region
+    default_model: Optional[
+        str
+    ] = None  # if no equivalent model in allowed region - default all requests to this model

     @model_validator(mode="before")
     @classmethod
@@ -929,12 +929,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
     blocked: bool = False  # allow/disallow requests for this end-user
     max_budget: Optional[float] = None
     budget_id: Optional[str] = None  # give either a budget_id or max_budget
-    allowed_model_region: Optional[AllowedModelRegion] = (
-        None  # require all user requests to use models in this specific region
-    )
-    default_model: Optional[str] = (
-        None  # if no equivalent model in allowed region - default all requests to this model
-    )
+    allowed_model_region: Optional[
+        AllowedModelRegion
+    ] = None  # require all user requests to use models in this specific region
+    default_model: Optional[
+        str
+    ] = None  # if no equivalent model in allowed region - default all requests to this model


 class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
@@ -1070,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):

 class AddTeamCallback(LiteLLMPydanticObjectBase):
     callback_name: str
-    callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
-        "success_and_failure"
-    )
+    callback_type: Optional[
+        Literal["success", "failure", "success_and_failure"]
+    ] = "success_and_failure"
     callback_vars: Dict[str, str]

     @model_validator(mode="before")
@@ -1329,9 +1329,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
     stored_in_db: Optional[bool]
     field_default_value: Any
     premium_field: bool = False
-    nested_fields: Optional[List[FieldDetail]] = (
-        None  # For nested dictionary or Pydantic fields
-    )
+    nested_fields: Optional[
+        List[FieldDetail]
+    ] = None  # For nested dictionary or Pydantic fields


 class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
@@ -1558,6 +1558,7 @@ class UserAPIKeyAuth(
     user_tpm_limit: Optional[int] = None
     user_rpm_limit: Optional[int] = None
     user_email: Optional[str] = None
+    request_route: Optional[str] = None

     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -1597,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
     budget_id: Optional[str] = None
     created_at: datetime
     updated_at: datetime
-    user: Optional[Any] = (
-        None  # You might want to replace 'Any' with a more specific type if available
-    )
+    user: Optional[
+        Any
+    ] = None  # You might want to replace 'Any' with a more specific type if available
     litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

     model_config = ConfigDict(protected_namespaces=())
@@ -2345,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
 # Organization Member Requests
 class OrganizationMemberAddRequest(OrgMemberAddRequest):
     organization_id: str
-    max_budget_in_organization: Optional[float] = (
-        None  # Users max budget within the organization
-    )
+    max_budget_in_organization: Optional[
+        float
+    ] = None  # Users max budget within the organization


 class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -2536,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
     Maps provider names to their budget configs.
     """

-    providers: Dict[str, ProviderBudgetResponseObject] = (
-        {}
-    )  # Dictionary mapping provider names to their budget configurations
+    providers: Dict[
+        str, ProviderBudgetResponseObject
+    ] = {}  # Dictionary mapping provider names to their budget configurations


 class ProxyStateVariables(TypedDict):
@@ -2666,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
     enforce_rbac: bool = False
     roles_jwt_field: Optional[str] = None  # v2 on role mappings
     role_mappings: Optional[List[RoleMapping]] = None
-    object_id_jwt_field: Optional[str] = (
-        None  # can be either user / team, inferred from the role mapping
-    )
+    object_id_jwt_field: Optional[
+        str
+    ] = None  # can be either user / team, inferred from the role mapping
     scope_mappings: Optional[List[ScopeMapping]] = None
     enforce_scope_based_access: bool = False
     enforce_team_based_model_access: bool = False
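
Most of the hunks in this file are line-wrapping churn from a formatter version change; the one substantive addition is the request_route field on UserAPIKeyAuth. A small illustration of the intended use (values are made up):

from litellm.proxy._types import UserAPIKeyAuth

# The auth layer stamps the request path onto the auth object, so downstream
# hooks can tell LLM traffic apart from admin/UI endpoints.
auth = UserAPIKeyAuth(api_key="sk-test", request_route="/chat/completions")
assert auth.request_route == "/chat/completions"

# The field defaults to None for callers that never set it.
assert UserAPIKeyAuth(api_key="sk-test").request_route is None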

View file

@@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler:
                     key_name="failed-to-connect-to-db",
                     token="failed-to-connect-to-db",
                     user_id=litellm_proxy_admin_name,
+                    request_route=route,
                 )
             else:
                 # raise the exception to the caller
@@ -87,6 +88,7 @@ class UserAPIKeyAuthExceptionHandler:
         user_api_key_dict = UserAPIKeyAuth(
             parent_otel_span=parent_otel_span,
             api_key=api_key,
+            request_route=route,
         )
         asyncio.create_task(
             proxy_logging_obj.post_call_failure_hook(

View file

@@ -1023,6 +1023,7 @@ async def user_api_key_auth(
     """
     request_data = await _read_request_body(request=request)
+    route: str = get_request_route(request=request)

     user_api_key_auth_obj = await _user_api_key_auth_builder(
         request=request,
@@ -1038,6 +1039,8 @@ async def user_api_key_auth(
     if end_user_id is not None:
         user_api_key_auth_obj.end_user_id = end_user_id

+    user_api_key_auth_obj.request_route = route
+
     return user_api_key_auth_obj

View file

@@ -13,6 +13,7 @@ from litellm.litellm_core_utils.core_helpers import (
 from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.proxy.utils import ProxyUpdateSpend
 from litellm.types.utils import (
     StandardLoggingPayload,
@@ -33,8 +34,13 @@ class _ProxyDBLogger(CustomLogger):
         original_exception: Exception,
         user_api_key_dict: UserAPIKeyAuth,
     ):
+        request_route = user_api_key_dict.request_route
         if _ProxyDBLogger._should_track_errors_in_db() is False:
             return
+        elif request_route is not None and not RouteChecks.is_llm_api_route(
+            route=request_route
+        ):
+            return

         from litellm.proxy.proxy_server import proxy_logging_obj
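
The new guard hinges on RouteChecks.is_llm_api_route. Judging from the test added later in this commit (which uses /custom/route), the expected classification is roughly as follows; this is an illustration, not an exhaustive route list:

from litellm.proxy.auth.route_checks import RouteChecks

# LLM traffic still reaches spend/error tracking...
assert RouteChecks.is_llm_api_route(route="/chat/completions") is True
# ...while non-LLM routes (admin, UI, custom endpoints) return early,
# which is what stops admin-endpoint errors from flooding the spend logs.
assert RouteChecks.is_llm_api_route(route="/custom/route") is False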

View file

@@ -577,9 +577,9 @@ async def generate_key_fn(  # noqa: PLR0915
             request_type="key", **data_json, table_name="key"
         )

-        response["soft_budget"] = (
-            data.soft_budget
-        )  # include the user-input soft budget in the response
+        response[
+            "soft_budget"
+        ] = data.soft_budget  # include the user-input soft budget in the response

         response = GenerateKeyResponse(**response)
@@ -1467,10 +1467,10 @@ async def delete_verification_tokens(
     try:
         if prisma_client:
             tokens = [_hash_token_if_needed(token=key) for key in tokens]
-            _keys_being_deleted: List[LiteLLM_VerificationToken] = (
-                await prisma_client.db.litellm_verificationtoken.find_many(
-                    where={"token": {"in": tokens}}
-                )
-            )
+            _keys_being_deleted: List[
+                LiteLLM_VerificationToken
+            ] = await prisma_client.db.litellm_verificationtoken.find_many(
+                where={"token": {"in": tokens}}
+            )

             # Assuming 'db' is your Prisma Client instance
@@ -1572,9 +1572,9 @@ async def _rotate_master_key(
     from litellm.proxy.proxy_server import proxy_config

     try:
-        models: Optional[List] = (
-            await prisma_client.db.litellm_proxymodeltable.find_many()
-        )
+        models: Optional[
+            List
+        ] = await prisma_client.db.litellm_proxymodeltable.find_many()
     except Exception:
         models = None
     # 2. process model table
@@ -1861,11 +1861,11 @@ async def validate_key_list_check(
             param="user_id",
             code=status.HTTP_403_FORBIDDEN,
         )
-    complete_user_info_db_obj: Optional[BaseModel] = (
-        await prisma_client.db.litellm_usertable.find_unique(
-            where={"user_id": user_api_key_dict.user_id},
-            include={"organization_memberships": True},
-        )
-    )
+    complete_user_info_db_obj: Optional[
+        BaseModel
+    ] = await prisma_client.db.litellm_usertable.find_unique(
+        where={"user_id": user_api_key_dict.user_id},
+        include={"organization_memberships": True},
+    )

     if complete_user_info_db_obj is None:
@@ -1926,10 +1926,10 @@ async def get_admin_team_ids(
     if complete_user_info is None:
         return []

     # Get all teams that user is an admin of
-    teams: Optional[List[BaseModel]] = (
-        await prisma_client.db.litellm_teamtable.find_many(
-            where={"team_id": {"in": complete_user_info.teams}}
-        )
-    )
+    teams: Optional[
+        List[BaseModel]
+    ] = await prisma_client.db.litellm_teamtable.find_many(
+        where={"team_id": {"in": complete_user_info.teams}}
+    )
     if teams is None:
         return []
@@ -2080,7 +2080,6 @@ async def _list_key_helper(
             "total_pages": int,
         }
     """
-    # Prepare filter conditions
     where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
     where.update(_get_condition_to_filter_out_ui_session_tokens())
@@ -2110,7 +2109,7 @@ async def _list_key_helper(
     # Combine conditions with OR if we have multiple conditions
     if len(or_conditions) > 1:
-        where["OR"] = or_conditions
+        where = {"AND": [where, {"OR": or_conditions}]}
     elif len(or_conditions) == 1:
         where.update(or_conditions[0])
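
The old top-level assignment would clobber any OR condition already present in the base filters (presumably how internal users ended up seeing litellm-dashboard UI session keys); nesting both groups under an explicit AND guarantees every returned key satisfies the base filters and at least one ownership branch. The resulting Prisma filter has roughly this shape (IDs borrowed from the new test below; the exact base-filter form is an assumption grounded in that test's assertion):

where = {
    "AND": [
        # Base filter: never return LiteLLM UI session keys.
        {"team_id": {"not": "litellm-dashboard"}},
        {
            "OR": [
                # Keys the internal user owns...
                {"user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00"},
                # ...or keys on teams they administer.
                {"team_id": {"in": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"]}},
            ]
        },
    ]
}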

View file

@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}
@@ -562,9 +562,9 @@ class Router:
                 )
             )

-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy

         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
@@ -1105,9 +1105,9 @@ class Router:
         """
         Adds default litellm params to kwargs, if set.
         """
-        self.default_litellm_params[metadata_variable_name] = (
-            self.default_litellm_params.pop("metadata", {})
-        )
+        self.default_litellm_params[
+            metadata_variable_name
+        ] = self.default_litellm_params.pop("metadata", {})
         for k, v in self.default_litellm_params.items():
             if (
                 k not in kwargs and v is not None
@@ -3243,11 +3243,11 @@ class Router:
             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=context_window_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=context_window_fallbacks,
+                        model_group=model_group,
+                    )
                     if fallback_model_group is None:
                         raise original_exception
@@ -3279,11 +3279,11 @@ class Router:
                     e.message += "\n{}".format(error_message)
             elif isinstance(e, litellm.ContentPolicyViolationError):
                 if content_policy_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=content_policy_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=content_policy_fallbacks,
+                        model_group=model_group,
+                    )
                     if fallback_model_group is None:
                         raise original_exception
@@ -4853,10 +4853,11 @@ class Router:
         from litellm.utils import _update_dictionary

         model_info: Optional[ModelInfo] = None
+        custom_model_info: Optional[dict] = None
         litellm_model_name_model_info: Optional[ModelInfo] = None

         try:
-            model_info = litellm.get_model_info(model=model_id)
+            custom_model_info = litellm.model_cost.get(model_id)
         except Exception:
             pass
@@ -4865,14 +4866,16 @@ class Router:
         except Exception:
             pass

-        if model_info is not None and litellm_model_name_model_info is not None:
+        if custom_model_info is not None and litellm_model_name_model_info is not None:
             model_info = cast(
                 ModelInfo,
                 _update_dictionary(
                     cast(dict, litellm_model_name_model_info).copy(),
-                    cast(dict, model_info),
+                    custom_model_info,
                 ),
             )
+        elif litellm_model_name_model_info is not None:
+            model_info = litellm_model_name_model_info

         return model_info
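
The bug: when the per-deployment lookup raised (no entry for the model_id), model_info stayed None and the method returned None even though pricing for the underlying model name had been fetched successfully — hence the model hub's None prices. The fix treats the per-id entry as an optional override and adds an explicit fallback to the base info. A condensed sketch of the new precedence, using the method's own variable names and assuming _update_dictionary(a, b) overlays b onto a copy of a:

import litellm

# Both lookups are wrapped in try/except in the real method.
custom_model_info = litellm.model_cost.get(model_id)  # per-deployment override, may be None
base_info = litellm.get_model_info(model=litellm_model_name)  # pricing for the underlying model

if custom_model_info is not None and base_info is not None:
    model_info = {**dict(base_info), **custom_model_info}  # override wins where keys overlap
elif base_info is not None:
    model_info = base_info  # no override: fall back to the base model's pricing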

View file

@@ -2,7 +2,7 @@ import asyncio
 import json
 import os
 import sys
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
 from fastapi import HTTPException, Request, status
@@ -110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded():
         )

     assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
+
+
+@pytest.mark.asyncio
+async def test_route_passed_to_post_call_failure_hook():
+    """
+    This route is used by proxy track_cost_callback's async_post_call_failure_hook
+    to check if the route is an LLM route.
+    """
+    handler = UserAPIKeyAuthExceptionHandler()
+
+    # Mock request and other dependencies
+    mock_request = MagicMock()
+    mock_request_data = {}
+    test_route = "/custom/route"
+    mock_span = None
+    mock_api_key = "test-key"
+
+    # Mock proxy_logging_obj.post_call_failure_hook
+    with patch(
+        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
+        new_callable=AsyncMock,
+    ) as mock_post_call_failure_hook:
+        # Test with DB connection error
+        with patch(
+            "litellm.proxy.proxy_server.general_settings",
+            {"allow_requests_on_db_unavailable": False},
+        ):
+            try:
+                await handler._handle_authentication_error(
+                    PrismaError(),
+                    mock_request,
+                    mock_request_data,
+                    test_route,
+                    mock_span,
+                    mock_api_key,
+                )
+            except Exception:
+                pass
+
+            await asyncio.sleep(1)
+
+            # Verify post_call_failure_hook was called with the correct route
+            mock_post_call_failure_hook.assert_called_once()
+            call_args = mock_post_call_failure_hook.call_args[1]
+            assert call_args["user_api_key_dict"].request_route == test_route

View file

@@ -81,3 +81,48 @@ async def test_async_post_call_failure_hook():
     assert metadata["status"] == "failure"
     assert "error_information" in metadata
     assert metadata["original_key"] == "original_value"
+
+
+@pytest.mark.asyncio
+async def test_async_post_call_failure_hook_non_llm_route():
+    # Setup
+    logger = _ProxyDBLogger()
+
+    # Mock user_api_key_dict with a non-LLM route
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test_api_key",
+        key_alias="test_alias",
+        user_email="test@example.com",
+        user_id="test_user_id",
+        team_id="test_team_id",
+        org_id="test_org_id",
+        team_alias="test_team_alias",
+        end_user_id="test_end_user_id",
+        request_route="/custom/route",  # Non-LLM route
+    )
+
+    # Mock request data
+    request_data = {
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "metadata": {"original_key": "original_value"},
+        "proxy_server_request": {"request_id": "test_request_id"},
+    }
+
+    # Mock exception
+    original_exception = Exception("Test exception")
+
+    # Mock update_database function
+    with patch(
+        "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database",
+        new_callable=AsyncMock,
+    ) as mock_update_database:
+        # Call the method
+        await logger.async_post_call_failure_hook(
+            request_data=request_data,
+            original_exception=original_exception,
+            user_api_key_dict=user_api_key_dict,
+        )
+
+        # Assert that update_database was NOT called for non-LLM routes
+        mock_update_database.assert_not_called()

View file

@@ -0,0 +1,48 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+
+from unittest.mock import AsyncMock, MagicMock
+
+from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper
+from litellm.proxy.proxy_server import app
+
+client = TestClient(app)
+
+
+@pytest.mark.asyncio
+async def test_list_keys():
+    mock_prisma_client = AsyncMock()
+    mock_find_many = AsyncMock(return_value=[])
+    mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many
+
+    args = {
+        "prisma_client": mock_prisma_client,
+        "page": 1,
+        "size": 50,
+        "user_id": "cda88cb4-cc2c-4e8c-b871-dc71ca111b00",
+        "team_id": None,
+        "organization_id": None,
+        "key_alias": None,
+        "exclude_team_id": None,
+        "return_full_object": True,
+        "admin_team_ids": ["28bd3181-02c5-48f2-b408-ce790fb3d5ba"],
+    }
+
+    try:
+        result = await _list_key_helper(**args)
+    except Exception as e:
+        print(f"error: {e}")
+
+    mock_find_many.assert_called_once()
+
+    where_condition = mock_find_many.call_args.kwargs["where"]
+    print(f"where_condition: {where_condition}")
+    assert json.dumps({"team_id": {"not": "litellm-dashboard"}}) in json.dumps(
+        where_condition
+    )

View file

@@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials():
     deployment = router.get_deployment(model_id=original_model_id)
     assert deployment is not None
     assert deployment.litellm_params.api_key == original_api_key
+
+
+def test_router_get_model_group_info():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {"model": "gpt-4"},
+            },
+        ],
+    )
+
+    model_group_info = router.get_model_group_info(model_group="gpt-4")
+
+    assert model_group_info is not None
+    assert model_group_info.model_group == "gpt-4"
+    assert model_group_info.input_cost_per_token > 0
+    assert model_group_info.output_cost_per_token > 0

View file

@@ -448,7 +448,7 @@ export default function ModelInfoView({
               </div>
               <div>
-                <Text className="font-medium">RPM VVV(Requests per Minute)</Text>
+                <Text className="font-medium">RPM (Requests per Minute)</Text>
                 {isEditing ? (
                   <Form.Item name="rpm" className="mb-0">
                     <NumericalInput placeholder="Enter RPM" />