mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Litellm UI qa 04 12 2025 p1 (#9955)
* fix(model_info_view.tsx): cleanup text
* fix(key_management_endpoints.py): fix filtering litellm-dashboard keys for internal users
* fix(proxy_track_cost_callback.py): prevent flooding spend logs with admin endpoint errors
* test: add unit testing for logic
* test(test_auth_exception_handler.py): add more unit testing
* fix(router.py): correctly handle retrieving model info on get_model_group_info
  (fixes issue where model hub was showing None prices)
* fix: fix linting errors
This commit is contained in:
parent
f8d52e2db9
commit
00e49380df
13 changed files with 249 additions and 80 deletions
|
@ -13,7 +13,7 @@ import os
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
def get_model_cost_map(url: str):
|
def get_model_cost_map(url: str) -> dict:
|
||||||
if (
|
if (
|
||||||
os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
|
os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
|
||||||
or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
|
or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -644,9 +644,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
|
||||||
allowed_cache_controls: Optional[list] = []
|
allowed_cache_controls: Optional[list] = []
|
||||||
config: Optional[dict] = {}
|
config: Optional[dict] = {}
|
||||||
permissions: Optional[dict] = {}
|
permissions: Optional[dict] = {}
|
||||||
model_max_budget: Optional[dict] = (
|
model_max_budget: Optional[
|
||||||
{}
|
dict
|
||||||
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
model_rpm_limit: Optional[dict] = None
|
model_rpm_limit: Optional[dict] = None
|
||||||
|
@ -902,12 +902,12 @@ class NewCustomerRequest(BudgetNewRequest):
|
||||||
alias: Optional[str] = None # human-friendly alias
|
alias: Optional[str] = None # human-friendly alias
|
||||||
blocked: bool = False # allow/disallow requests for this end-user
|
blocked: bool = False # allow/disallow requests for this end-user
|
||||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
allowed_model_region: Optional[
|
||||||
None # require all user requests to use models in this specific region
|
AllowedModelRegion
|
||||||
)
|
] = None # require all user requests to use models in this specific region
|
||||||
default_model: Optional[str] = (
|
default_model: Optional[
|
||||||
None # if no equivalent model in allowed region - default all requests to this model
|
str
|
||||||
)
|
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -929,12 +929,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
|
||||||
blocked: bool = False # allow/disallow requests for this end-user
|
blocked: bool = False # allow/disallow requests for this end-user
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||||
allowed_model_region: Optional[AllowedModelRegion] = (
|
allowed_model_region: Optional[
|
||||||
None # require all user requests to use models in this specific region
|
AllowedModelRegion
|
||||||
)
|
] = None # require all user requests to use models in this specific region
|
||||||
default_model: Optional[str] = (
|
default_model: Optional[
|
||||||
None # if no equivalent model in allowed region - default all requests to this model
|
str
|
||||||
)
|
] = None # if no equivalent model in allowed region - default all requests to this model
|
||||||
|
|
||||||
|
|
||||||
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
||||||
|
@ -1070,9 +1070,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
|
||||||
|
|
||||||
class AddTeamCallback(LiteLLMPydanticObjectBase):
|
class AddTeamCallback(LiteLLMPydanticObjectBase):
|
||||||
callback_name: str
|
callback_name: str
|
||||||
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
|
callback_type: Optional[
|
||||||
"success_and_failure"
|
Literal["success", "failure", "success_and_failure"]
|
||||||
)
|
] = "success_and_failure"
|
||||||
callback_vars: Dict[str, str]
|
callback_vars: Dict[str, str]
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@ -1329,9 +1329,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
|
||||||
stored_in_db: Optional[bool]
|
stored_in_db: Optional[bool]
|
||||||
field_default_value: Any
|
field_default_value: Any
|
||||||
premium_field: bool = False
|
premium_field: bool = False
|
||||||
nested_fields: Optional[List[FieldDetail]] = (
|
nested_fields: Optional[
|
||||||
None # For nested dictionary or Pydantic fields
|
List[FieldDetail]
|
||||||
)
|
] = None # For nested dictionary or Pydantic fields
|
||||||
|
|
||||||
|
|
||||||
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||||
|
@ -1558,6 +1558,7 @@ class UserAPIKeyAuth(
|
||||||
user_tpm_limit: Optional[int] = None
|
user_tpm_limit: Optional[int] = None
|
||||||
user_rpm_limit: Optional[int] = None
|
user_rpm_limit: Optional[int] = None
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
|
request_route: Optional[str] = None
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@ -1597,9 +1598,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
|
||||||
budget_id: Optional[str] = None
|
budget_id: Optional[str] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
user: Optional[Any] = (
|
user: Optional[
|
||||||
None # You might want to replace 'Any' with a more specific type if available
|
Any
|
||||||
)
|
] = None # You might want to replace 'Any' with a more specific type if available
|
||||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
@ -2345,9 +2346,9 @@ class TeamModelDeleteRequest(BaseModel):
|
||||||
# Organization Member Requests
|
# Organization Member Requests
|
||||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||||
organization_id: str
|
organization_id: str
|
||||||
max_budget_in_organization: Optional[float] = (
|
max_budget_in_organization: Optional[
|
||||||
None # Users max budget within the organization
|
float
|
||||||
)
|
] = None # Users max budget within the organization
|
||||||
|
|
||||||
|
|
||||||
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
|
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
|
||||||
|
@ -2536,9 +2537,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
|
||||||
Maps provider names to their budget configs.
|
Maps provider names to their budget configs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
providers: Dict[str, ProviderBudgetResponseObject] = (
|
providers: Dict[
|
||||||
{}
|
str, ProviderBudgetResponseObject
|
||||||
) # Dictionary mapping provider names to their budget configurations
|
] = {} # Dictionary mapping provider names to their budget configurations
|
||||||
|
|
||||||
|
|
||||||
class ProxyStateVariables(TypedDict):
|
class ProxyStateVariables(TypedDict):
|
||||||
|
@ -2666,9 +2667,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||||
enforce_rbac: bool = False
|
enforce_rbac: bool = False
|
||||||
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
||||||
role_mappings: Optional[List[RoleMapping]] = None
|
role_mappings: Optional[List[RoleMapping]] = None
|
||||||
object_id_jwt_field: Optional[str] = (
|
object_id_jwt_field: Optional[
|
||||||
None # can be either user / team, inferred from the role mapping
|
str
|
||||||
)
|
] = None # can be either user / team, inferred from the role mapping
|
||||||
scope_mappings: Optional[List[ScopeMapping]] = None
|
scope_mappings: Optional[List[ScopeMapping]] = None
|
||||||
enforce_scope_based_access: bool = False
|
enforce_scope_based_access: bool = False
|
||||||
enforce_team_based_model_access: bool = False
|
enforce_team_based_model_access: bool = False
|
||||||
|
|
|
@ -68,6 +68,7 @@ class UserAPIKeyAuthExceptionHandler:
|
||||||
key_name="failed-to-connect-to-db",
|
key_name="failed-to-connect-to-db",
|
||||||
token="failed-to-connect-to-db",
|
token="failed-to-connect-to-db",
|
||||||
user_id=litellm_proxy_admin_name,
|
user_id=litellm_proxy_admin_name,
|
||||||
|
request_route=route,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# raise the exception to the caller
|
# raise the exception to the caller
|
||||||
|
@ -87,6 +88,7 @@ class UserAPIKeyAuthExceptionHandler:
|
||||||
user_api_key_dict = UserAPIKeyAuth(
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
request_route=route,
|
||||||
)
|
)
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
proxy_logging_obj.post_call_failure_hook(
|
proxy_logging_obj.post_call_failure_hook(
|
||||||
|
|
|
@ -1023,6 +1023,7 @@ async def user_api_key_auth(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_data = await _read_request_body(request=request)
|
request_data = await _read_request_body(request=request)
|
||||||
|
route: str = get_request_route(request=request)
|
||||||
|
|
||||||
user_api_key_auth_obj = await _user_api_key_auth_builder(
|
user_api_key_auth_obj = await _user_api_key_auth_builder(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -1038,6 +1039,8 @@ async def user_api_key_auth(
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
user_api_key_auth_obj.end_user_id = end_user_id
|
user_api_key_auth_obj.end_user_id = end_user_id
|
||||||
|
|
||||||
|
user_api_key_auth_obj.request_route = route
|
||||||
|
|
||||||
return user_api_key_auth_obj
|
return user_api_key_auth_obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from litellm.litellm_core_utils.core_helpers import (
|
||||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||||
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
from litellm.proxy.utils import ProxyUpdateSpend
|
from litellm.proxy.utils import ProxyUpdateSpend
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
StandardLoggingPayload,
|
StandardLoggingPayload,
|
||||||
|
@ -33,8 +34,13 @@ class _ProxyDBLogger(CustomLogger):
|
||||||
original_exception: Exception,
|
original_exception: Exception,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
):
|
):
|
||||||
|
request_route = user_api_key_dict.request_route
|
||||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||||
return
|
return
|
||||||
|
elif request_route is not None and not RouteChecks.is_llm_api_route(
|
||||||
|
route=request_route
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||||
|
|
||||||
|
|
|
@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
request_type="key", **data_json, table_name="key"
|
request_type="key", **data_json, table_name="key"
|
||||||
)
|
)
|
||||||
|
|
||||||
response["soft_budget"] = (
|
response[
|
||||||
data.soft_budget
|
"soft_budget"
|
||||||
) # include the user-input soft budget in the response
|
] = data.soft_budget # include the user-input soft budget in the response
|
||||||
|
|
||||||
response = GenerateKeyResponse(**response)
|
response = GenerateKeyResponse(**response)
|
||||||
|
|
||||||
|
@ -1467,10 +1467,10 @@ async def delete_verification_tokens(
|
||||||
try:
|
try:
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||||
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
|
_keys_being_deleted: List[
|
||||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
LiteLLM_VerificationToken
|
||||||
where={"token": {"in": tokens}}
|
] = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||||
)
|
where={"token": {"in": tokens}}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assuming 'db' is your Prisma Client instance
|
# Assuming 'db' is your Prisma Client instance
|
||||||
|
@ -1572,9 +1572,9 @@ async def _rotate_master_key(
|
||||||
from litellm.proxy.proxy_server import proxy_config
|
from litellm.proxy.proxy_server import proxy_config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
models: Optional[List] = (
|
models: Optional[
|
||||||
await prisma_client.db.litellm_proxymodeltable.find_many()
|
List
|
||||||
)
|
] = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||||
except Exception:
|
except Exception:
|
||||||
models = None
|
models = None
|
||||||
# 2. process model table
|
# 2. process model table
|
||||||
|
@ -1861,11 +1861,11 @@ async def validate_key_list_check(
|
||||||
param="user_id",
|
param="user_id",
|
||||||
code=status.HTTP_403_FORBIDDEN,
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
)
|
)
|
||||||
complete_user_info_db_obj: Optional[BaseModel] = (
|
complete_user_info_db_obj: Optional[
|
||||||
await prisma_client.db.litellm_usertable.find_unique(
|
BaseModel
|
||||||
where={"user_id": user_api_key_dict.user_id},
|
] = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
include={"organization_memberships": True},
|
where={"user_id": user_api_key_dict.user_id},
|
||||||
)
|
include={"organization_memberships": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
if complete_user_info_db_obj is None:
|
if complete_user_info_db_obj is None:
|
||||||
|
@ -1926,10 +1926,10 @@ async def get_admin_team_ids(
|
||||||
if complete_user_info is None:
|
if complete_user_info is None:
|
||||||
return []
|
return []
|
||||||
# Get all teams that user is an admin of
|
# Get all teams that user is an admin of
|
||||||
teams: Optional[List[BaseModel]] = (
|
teams: Optional[
|
||||||
await prisma_client.db.litellm_teamtable.find_many(
|
List[BaseModel]
|
||||||
where={"team_id": {"in": complete_user_info.teams}}
|
] = await prisma_client.db.litellm_teamtable.find_many(
|
||||||
)
|
where={"team_id": {"in": complete_user_info.teams}}
|
||||||
)
|
)
|
||||||
if teams is None:
|
if teams is None:
|
||||||
return []
|
return []
|
||||||
|
@ -2080,7 +2080,6 @@ async def _list_key_helper(
|
||||||
"total_pages": int,
|
"total_pages": int,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prepare filter conditions
|
# Prepare filter conditions
|
||||||
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
|
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
|
||||||
where.update(_get_condition_to_filter_out_ui_session_tokens())
|
where.update(_get_condition_to_filter_out_ui_session_tokens())
|
||||||
|
@ -2110,7 +2109,7 @@ async def _list_key_helper(
|
||||||
|
|
||||||
# Combine conditions with OR if we have multiple conditions
|
# Combine conditions with OR if we have multiple conditions
|
||||||
if len(or_conditions) > 1:
|
if len(or_conditions) > 1:
|
||||||
where["OR"] = or_conditions
|
where = {"AND": [where, {"OR": or_conditions}]}
|
||||||
elif len(or_conditions) == 1:
|
elif len(or_conditions) == 1:
|
||||||
where.update(or_conditions[0])
|
where.update(or_conditions[0])
|
||||||
|
|
||||||
|
|
|
@ -339,9 +339,9 @@ class Router:
|
||||||
) # names of models under litellm_params. ex. azure/chatgpt-v-2
|
) # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||||
self.deployment_latency_map = {}
|
self.deployment_latency_map = {}
|
||||||
### CACHING ###
|
### CACHING ###
|
||||||
cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
|
cache_type: Literal[
|
||||||
"local" # default to an in-memory cache
|
"local", "redis", "redis-semantic", "s3", "disk"
|
||||||
)
|
] = "local" # default to an in-memory cache
|
||||||
redis_cache = None
|
redis_cache = None
|
||||||
cache_config: Dict[str, Any] = {}
|
cache_config: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@ -562,9 +562,9 @@ class Router:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
self.model_group_retry_policy: Optional[
|
||||||
model_group_retry_policy
|
Dict[str, RetryPolicy]
|
||||||
)
|
] = model_group_retry_policy
|
||||||
|
|
||||||
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
|
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
|
||||||
if allowed_fails_policy is not None:
|
if allowed_fails_policy is not None:
|
||||||
|
@ -1105,9 +1105,9 @@ class Router:
|
||||||
"""
|
"""
|
||||||
Adds default litellm params to kwargs, if set.
|
Adds default litellm params to kwargs, if set.
|
||||||
"""
|
"""
|
||||||
self.default_litellm_params[metadata_variable_name] = (
|
self.default_litellm_params[
|
||||||
self.default_litellm_params.pop("metadata", {})
|
metadata_variable_name
|
||||||
)
|
] = self.default_litellm_params.pop("metadata", {})
|
||||||
for k, v in self.default_litellm_params.items():
|
for k, v in self.default_litellm_params.items():
|
||||||
if (
|
if (
|
||||||
k not in kwargs and v is not None
|
k not in kwargs and v is not None
|
||||||
|
@ -3243,11 +3243,11 @@ class Router:
|
||||||
|
|
||||||
if isinstance(e, litellm.ContextWindowExceededError):
|
if isinstance(e, litellm.ContextWindowExceededError):
|
||||||
if context_window_fallbacks is not None:
|
if context_window_fallbacks is not None:
|
||||||
fallback_model_group: Optional[List[str]] = (
|
fallback_model_group: Optional[
|
||||||
self._get_fallback_model_group_from_fallbacks(
|
List[str]
|
||||||
fallbacks=context_window_fallbacks,
|
] = self._get_fallback_model_group_from_fallbacks(
|
||||||
model_group=model_group,
|
fallbacks=context_window_fallbacks,
|
||||||
)
|
model_group=model_group,
|
||||||
)
|
)
|
||||||
if fallback_model_group is None:
|
if fallback_model_group is None:
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
@ -3279,11 +3279,11 @@ class Router:
|
||||||
e.message += "\n{}".format(error_message)
|
e.message += "\n{}".format(error_message)
|
||||||
elif isinstance(e, litellm.ContentPolicyViolationError):
|
elif isinstance(e, litellm.ContentPolicyViolationError):
|
||||||
if content_policy_fallbacks is not None:
|
if content_policy_fallbacks is not None:
|
||||||
fallback_model_group: Optional[List[str]] = (
|
fallback_model_group: Optional[
|
||||||
self._get_fallback_model_group_from_fallbacks(
|
List[str]
|
||||||
fallbacks=content_policy_fallbacks,
|
] = self._get_fallback_model_group_from_fallbacks(
|
||||||
model_group=model_group,
|
fallbacks=content_policy_fallbacks,
|
||||||
)
|
model_group=model_group,
|
||||||
)
|
)
|
||||||
if fallback_model_group is None:
|
if fallback_model_group is None:
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
@ -4853,10 +4853,11 @@ class Router:
|
||||||
from litellm.utils import _update_dictionary
|
from litellm.utils import _update_dictionary
|
||||||
|
|
||||||
model_info: Optional[ModelInfo] = None
|
model_info: Optional[ModelInfo] = None
|
||||||
|
custom_model_info: Optional[dict] = None
|
||||||
litellm_model_name_model_info: Optional[ModelInfo] = None
|
litellm_model_name_model_info: Optional[ModelInfo] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_info = litellm.get_model_info(model=model_id)
|
custom_model_info = litellm.model_cost.get(model_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -4865,14 +4866,16 @@ class Router:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if model_info is not None and litellm_model_name_model_info is not None:
|
if custom_model_info is not None and litellm_model_name_model_info is not None:
|
||||||
model_info = cast(
|
model_info = cast(
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
_update_dictionary(
|
_update_dictionary(
|
||||||
cast(dict, litellm_model_name_model_info).copy(),
|
cast(dict, litellm_model_name_model_info).copy(),
|
||||||
cast(dict, model_info),
|
custom_model_info,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
elif litellm_model_name_model_info is not None:
|
||||||
|
model_info = litellm_model_name_model_info
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException, Request, status
|
from fastapi import HTTPException, Request, status
|
||||||
|
@ -110,3 +110,45 @@ async def test_handle_authentication_error_budget_exceeded():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
|
assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_route_passed_to_post_call_failure_hook():
    """
    Verify the request route is forwarded to `post_call_failure_hook`.

    This route is used by proxy track_cost_callback's async_post_call_failure_hook to check if the route is an LLM route
    """
    handler = UserAPIKeyAuthExceptionHandler()

    # Mock request and other dependencies
    mock_request = MagicMock()
    mock_request_data = {}
    test_route = "/custom/route"
    mock_span = None
    mock_api_key = "test-key"

    # Mock proxy_logging_obj.post_call_failure_hook
    with patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
    ) as mock_post_call_failure_hook:
        # Test with DB connection error
        with patch(
            "litellm.proxy.proxy_server.general_settings",
            {"allow_requests_on_db_unavailable": False},
        ):
            try:
                await handler._handle_authentication_error(
                    PrismaError(),
                    mock_request,
                    mock_request_data,
                    test_route,
                    mock_span,
                    mock_api_key,
                )
            except Exception:
                # Deliberate best-effort: this test only inspects the failure
                # hook call, not whether the handler raises or returns.
                pass

            # The original called `asyncio.sleep(1)` without awaiting it, which
            # never slept and triggered a "coroutine was never awaited" warning.
            # Awaiting a zero-length sleep yields to the event loop once so any
            # task scheduled by the handler gets a chance to run, without
            # slowing the test down.
            await asyncio.sleep(0)

            # Verify post_call_failure_hook was called with the correct route
            mock_post_call_failure_hook.assert_called_once()
            call_args = mock_post_call_failure_hook.call_args[1]
            assert call_args["user_api_key_dict"].request_route == test_route
|
||||||
|
|
|
@ -81,3 +81,48 @@ async def test_async_post_call_failure_hook():
|
||||||
assert metadata["status"] == "failure"
|
assert metadata["status"] == "failure"
|
||||||
assert "error_information" in metadata
|
assert "error_information" in metadata
|
||||||
assert metadata["original_key"] == "original_value"
|
assert metadata["original_key"] == "original_value"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_async_post_call_failure_hook_non_llm_route():
    """The failure hook must not call `update_database` for a non-LLM route."""
    db_logger = _ProxyDBLogger()

    # Auth object carrying a request_route that is not an LLM API route
    auth = UserAPIKeyAuth(
        api_key="test_api_key",
        key_alias="test_alias",
        user_email="test@example.com",
        user_id="test_user_id",
        team_id="test_team_id",
        org_id="test_org_id",
        team_alias="test_team_alias",
        end_user_id="test_end_user_id",
        request_route="/custom/route",  # Non-LLM route
    )

    # Request payload handed to the hook
    payload = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "Hello"}],
        "metadata": {"original_key": "original_value"},
        "proxy_server_request": {"request_id": "test_request_id"},
    }

    # Exception the hook is asked to record
    failure = Exception("Test exception")

    # Replace the DB writer so any attempted write is observable
    update_database_path = (
        "litellm.proxy.db.db_spend_update_writer.DBSpendUpdateWriter.update_database"
    )
    with patch(update_database_path, new_callable=AsyncMock) as update_database_mock:
        await db_logger.async_post_call_failure_hook(
            request_data=payload,
            original_exception=failure,
            user_api_key_dict=auth,
        )

        # Nothing may be written for non-LLM routes
        update_database_mock.assert_not_called()
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from litellm.proxy.management_endpoints.key_management_endpoints import _list_key_helper
|
||||||
|
from litellm.proxy.proxy_server import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_list_keys():
    """
    Check the `where` filter that `_list_key_helper` passes to `find_many`.

    Called with both a user_id and admin_team_ids, the serialized filter must
    still contain the condition excluding the `litellm-dashboard` team.
    """
    find_many_mock = AsyncMock(return_value=[])
    prisma_mock = AsyncMock()
    prisma_mock.db.litellm_verificationtoken.find_many = find_many_mock

    helper_kwargs = dict(
        prisma_client=prisma_mock,
        page=1,
        size=50,
        user_id="cda88cb4-cc2c-4e8c-b871-dc71ca111b00",
        team_id=None,
        organization_id=None,
        key_alias=None,
        exclude_team_id=None,
        return_full_object=True,
        admin_team_ids=["28bd3181-02c5-48f2-b408-ce790fb3d5ba"],
    )

    # Best-effort call: errors are printed, not raised, and the assertions
    # below only inspect the arguments recorded on the find_many mock.
    try:
        await _list_key_helper(**helper_kwargs)
    except Exception as e:
        print(f"error: {e}")

    find_many_mock.assert_called_once()

    where_condition = find_many_mock.call_args.kwargs["where"]
    print(f"where_condition: {where_condition}")

    # Compare serialized forms so the condition is found at any nesting depth
    expected_fragment = json.dumps({"team_id": {"not": "litellm-dashboard"}})
    serialized_where = json.dumps(where_condition)
    assert expected_fragment in serialized_where
|
|
@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials():
|
||||||
deployment = router.get_deployment(model_id=original_model_id)
|
deployment = router.get_deployment(model_id=original_model_id)
|
||||||
assert deployment is not None
|
assert deployment is not None
|
||||||
assert deployment.litellm_params.api_key == original_api_key
|
assert deployment.litellm_params.api_key == original_api_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_get_model_group_info():
    """`get_model_group_info` must return positive token prices for "gpt-4"."""
    # One deployment per model group, each pointing at the same-named model
    deployments = [
        {"model_name": name, "litellm_params": {"model": name}}
        for name in ("gpt-3.5-turbo", "gpt-4")
    ]
    router = Router(model_list=deployments)

    group_info = router.get_model_group_info(model_group="gpt-4")

    assert group_info is not None
    assert group_info.model_group == "gpt-4"
    assert group_info.input_cost_per_token > 0
    assert group_info.output_cost_per_token > 0
|
|
@ -448,7 +448,7 @@ export default function ModelInfoView({
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<Text className="font-medium">RPM VVV(Requests per Minute)</Text>
|
<Text className="font-medium">RPM (Requests per Minute)</Text>
|
||||||
{isEditing ? (
|
{isEditing ? (
|
||||||
<Form.Item name="rpm" className="mb-0">
|
<Form.Item name="rpm" className="mb-0">
|
||||||
<NumericalInput placeholder="Enter RPM" />
|
<NumericalInput placeholder="Enter RPM" />
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue