forked from phoenix/litellm-mirror
test_is_ui_route_allowed
This commit is contained in:
parent
cdb94ffe16
commit
574f07d782
3 changed files with 50 additions and 21 deletions
|
@ -28,7 +28,7 @@ from litellm.proxy._types import (
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.route_checks import is_llm_api_route
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
||||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ def common_checks( # noqa: PLR0915
|
||||||
general_settings.get("enforce_user_param", None) is not None
|
general_settings.get("enforce_user_param", None) is not None
|
||||||
and general_settings["enforce_user_param"] is True
|
and general_settings["enforce_user_param"] is True
|
||||||
):
|
):
|
||||||
if is_llm_api_route(route=route) and "user" not in request_body:
|
if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||||
)
|
)
|
||||||
|
@ -154,7 +154,7 @@ def common_checks( # noqa: PLR0915
|
||||||
+ CommonProxyErrors.not_premium_user.value
|
+ CommonProxyErrors.not_premium_user.value
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_llm_api_route(route=route):
|
if RouteChecks.is_llm_api_route(route=route):
|
||||||
# loop through each enforced param
|
# loop through each enforced param
|
||||||
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
|
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
|
||||||
for enforced_param in general_settings["enforced_params"]:
|
for enforced_param in general_settings["enforced_params"]:
|
||||||
|
@ -182,7 +182,7 @@ def common_checks( # noqa: PLR0915
|
||||||
and global_proxy_spend is not None
|
and global_proxy_spend is not None
|
||||||
# only run global budget checks for OpenAI routes
|
# only run global budget checks for OpenAI routes
|
||||||
# Reason - the Admin UI should continue working if the proxy crosses it's global budget
|
# Reason - the Admin UI should continue working if the proxy crosses it's global budget
|
||||||
and is_llm_api_route(route=route)
|
and RouteChecks.is_llm_api_route(route=route)
|
||||||
and route != "/v1/models"
|
and route != "/v1/models"
|
||||||
and route != "/models"
|
and route != "/models"
|
||||||
):
|
):
|
||||||
|
|
|
@ -122,6 +122,11 @@ def _is_ui_route_allowed(
|
||||||
):
|
):
|
||||||
# Do something if the current route starts with any of the allowed routes
|
# Do something if the current route starts with any of the allowed routes
|
||||||
return True
|
return True
|
||||||
|
elif any(
|
||||||
|
RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
|
||||||
|
for allowed_route in allowed_routes
|
||||||
|
):
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -293,26 +293,50 @@ async def test_auth_with_allowed_routes(route, should_raise_error):
|
||||||
setattr(proxy_server, "general_settings", initial_general_settings)
|
setattr(proxy_server, "general_settings", initial_general_settings)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("route", ["/global/spend/logs", "/key/delete"])
|
@pytest.mark.parametrize(
|
||||||
def test_is_ui_route_allowed(route):
|
"route, user_role, expected_result",
|
||||||
|
[
|
||||||
|
# Proxy Admin checks
|
||||||
|
("/global/spend/logs", "proxy_admin", True),
|
||||||
|
("/key/delete", "proxy_admin", True),
|
||||||
|
("/key/generate", "proxy_admin", True),
|
||||||
|
("/key/regenerate", "proxy_admin", True),
|
||||||
|
# Internal User checks - allowed routes
|
||||||
|
("/global/spend/logs", "internal_user", True),
|
||||||
|
("/key/delete", "internal_user", True),
|
||||||
|
("/key/generate", "internal_user", True),
|
||||||
|
("/key/82akk800000000jjsk/regenerate", "internal_user", True),
|
||||||
|
# Internal User checks - disallowed routes
|
||||||
|
("/organization/member_add", "internal_user", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_ui_route_allowed(route, user_role, expected_result):
|
||||||
from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed
|
from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed
|
||||||
from litellm.proxy._types import LiteLLM_UserTable
|
from litellm.proxy._types import LiteLLM_UserTable
|
||||||
|
|
||||||
|
user_obj = LiteLLM_UserTable(
|
||||||
|
user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
|
||||||
|
max_budget=None,
|
||||||
|
spend=0.0,
|
||||||
|
model_max_budget={},
|
||||||
|
model_spend={},
|
||||||
|
user_email="my-test-email@1234.com",
|
||||||
|
models=[],
|
||||||
|
tpm_limit=None,
|
||||||
|
rpm_limit=None,
|
||||||
|
user_role=user_role,
|
||||||
|
organization_memberships=[],
|
||||||
|
)
|
||||||
|
|
||||||
received_args: dict = {
|
received_args: dict = {
|
||||||
"route": route,
|
"route": route,
|
||||||
"user_obj": LiteLLM_UserTable(
|
"user_obj": user_obj,
|
||||||
user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297",
|
|
||||||
max_budget=None,
|
|
||||||
spend=0.0,
|
|
||||||
model_max_budget={},
|
|
||||||
model_spend={},
|
|
||||||
user_email="my-test-email@1234.com",
|
|
||||||
models=[],
|
|
||||||
tpm_limit=None,
|
|
||||||
rpm_limit=None,
|
|
||||||
user_role="internal_user",
|
|
||||||
organization_memberships=[],
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
try:
|
||||||
assert _is_ui_route_allowed(**received_args)
|
assert _is_ui_route_allowed(**received_args) == expected_result
|
||||||
|
except Exception as e:
|
||||||
|
# If expected result is False, we expect an error
|
||||||
|
if expected_result is False:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue