diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 7d29032c6..2f2450340 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -9,6 +9,7 @@ Run checks for:
 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
 """
+import re
 import time
 import traceback
 from datetime import datetime
@@ -34,6 +35,7 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
+from litellm.router_utils.pattern_match_deployments import PatternMatchRouter
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes

 from .auth_checks_organization import organization_role_based_access_check
@@ -48,8 +50,8 @@ else:

 last_db_access_time = LimitedSizeOrderedDict(max_size=100)
 db_cache_expiry = 5 # refresh every 5s
-
 all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
+pattern_router = PatternMatchRouter()


 def common_checks(  # noqa: PLR0915
@@ -90,17 +92,15 @@ def common_checks(  # noqa: PLR0915
     ):
         # this means the team has access to all models on the proxy
         if (
-            "all-proxy-models" in team_object.models
-            or "*" in team_object.models
-            or "openai/*" in team_object.models
+            _model_is_within_list_of_allowed_models(
+                model=_model, allowed_models=team_object.models
+            )
+            is True
         ):
-            # this means the team has access to all models on the proxy
             pass
         # check if the team model is an access_group
         elif model_in_access_group(_model, team_object.models) is True:
             pass
-        elif _model and "*" in _model:
-            pass
         else:
             raise Exception(
                 f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
@@ -828,7 +828,7 @@ async def can_key_call_model(
     model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
 ) -> Literal[True]:
     """
-    Checks if token can call a given model
+    Checks if token can call a given model, supporting regex/wildcard patterns

     Returns:
         - True: if token allowed to call model
@@ -863,25 +863,81 @@ async def can_key_call_model(
     # Filter out models that are access_groups
     filtered_models = [m for m in valid_token.models if m not in access_groups]
-
     filtered_models += models_in_current_access_groups
     verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
-    all_model_access: bool = False

     if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
+        _model_is_within_list_of_allowed_models(
+            model=model, allowed_models=filtered_models
+        )
+        is False
     ):
-        all_model_access = True
-
-    if model is not None and model not in filtered_models and all_model_access is False:
         raise ValueError(
             f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
         )
+
     valid_token.models = filtered_models
     verbose_proxy_logger.debug(
         f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
     )
     return True
+
+
+def _model_is_within_list_of_allowed_models(
+    model: str, allowed_models: List[str]
+) -> bool:
+    """
+    Checks if a model is within a list of allowed models (includes pattern matching checks).
+
+    Args:
+        model (str): The model to check (e.g., "custom_engine/model-123")
+        allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4", "claude-sonnet"])
+
+    Returns:
+        bool: True if
+            - len(allowed_models) == 0
+            - "*" in allowed_models (means all models are allowed)
+            - "all-proxy-models" in allowed_models (means all models are allowed)
+            - model is in the allowed_models list
+            - model matches any allowed pattern
+    """
+    # Check for universal access patterns
+    if len(allowed_models) == 0:
+        return True
+    if "*" in allowed_models:
+        return True
+    if "all-proxy-models" in allowed_models:
+        return True
+    # Note: kept for backwards compatibility. "openai/*" used to grant access to
+    # all models in our code, and removing it could impact existing customer code.
+    if "openai/*" in allowed_models:
+        return True
+    if model in allowed_models:
+        return True
+    if _model_matches_patterns(model=model, allowed_models=allowed_models) is True:
+        return True
+
+    return False
+
+
+def _model_matches_patterns(model: str, allowed_models: List[str]) -> bool:
+    """
+    Helper function to check if a model matches any of the allowed model patterns.
+
+    Args:
+        model (str): The model to check (e.g., "custom_engine/model-123")
+        allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4*"])
+
+    Returns:
+        bool: True if the model matches any allowed pattern, False otherwise
+    """
+    try:
+        # Use the module-level pattern router to turn wildcard patterns into regexes
+        for _model in allowed_models:
+            if "*" in _model:
+                regex_pattern = pattern_router._pattern_to_regex(_model)
+                if re.match(regex_pattern, model):
+                    return True
+        return False
+    except Exception as e:
+        verbose_proxy_logger.exception(f"Error in _model_matches_patterns: {str(e)}")
+        return False
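Reviewer note: the pattern check above delegates to PatternMatchRouter._pattern_to_regex from litellm/router_utils/pattern_match_deployments.py. Below is a minimal self-contained sketch of the intended matching behavior, assuming the conversion roughly escapes the literal parts of the pattern and widens each "*" -- pattern_to_regex_sketch is a hypothetical approximation for illustration, not the router's actual implementation:

import re

def pattern_to_regex_sketch(pattern: str) -> str:
    # Escape literal segments, turn each "*" into a greedy ".*" --
    # an approximation of what PatternMatchRouter._pattern_to_regex produces
    return ".*".join(re.escape(part) for part in pattern.split("*"))

# re.match anchors at the start, so the literal prefix must match exactly:
assert re.match(pattern_to_regex_sketch("custom_engine/*"), "custom_engine/model-123")
assert not re.match(pattern_to_regex_sketch("custom_engine/*"), "custom-engine/model-123")

This is why the unit tests below expect "custom-engine/model-123" (hyphen) to be rejected by the "custom_engine/*" (underscore) pattern.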
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 13fb1bcbe..596335e16 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,24 +1,5 @@
 model_list:
-  - model_name: gpt-4o
+  - model_name: custom_engine/*
     litellm_params:
-      model: openai/gpt-4o
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: fake-anthropic-endpoint
-    litellm_params:
-      model: anthropic/fake
-      api_base: https://exampleanthropicendpoint-production.up.railway.app/
-
-router_settings:
-  provider_budget_config:
-    openai:
-      budget_limit: 0.3 # float of $ value budget for time period
-      time_period: 1d # can be 1d, 2d, 30d
-    anthropic:
-      budget_limit: 5
-      time_period: 1d
-  redis_host: os.environ/REDIS_HOST
-  redis_port: os.environ/REDIS_PORT
-  redis_password: os.environ/REDIS_PASSWORD
-
-litellm_settings:
-  callbacks: ["prometheus"]
\ No newline at end of file
+      model: openai/custom_engine
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
\ No newline at end of file
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index b1d6b3dc6..991a58de4 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -85,6 +85,10 @@ model_list:
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: custom_engine/* # regex pattern matching test for /models and /model/info endpoints
+    litellm_params:
+      model: openai/custom_engine
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/


 # provider specific wildcard routing
@@ -96,6 +100,11 @@ model_list:
     litellm_params:
       model: "groq/*"
       api_key: os.environ/GROQ_API_KEY
+  - model_name: "custom-llm-engine/*"
+    litellm_params:
+      model: "openai/my-fake-model"
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      api_key: fake-key
   - model_name: mistral-embed
     litellm_params:
       model: mistral/mistral-embed
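Reviewer note: with the custom_engine/* entry above, any model name under that prefix resolves to the openai/custom_engine deployment. A quick manual check of the new behavior -- a sketch assuming a proxy running locally at http://localhost:4000 and a key generated with models=["custom_engine/*"]; both the base_url and the "sk-..." key are placeholders, not values from this PR:

import openai

client = openai.OpenAI(
    api_key="sk-my-wildcard-key",     # placeholder key scoped to custom_engine/*
    base_url="http://localhost:4000",  # placeholder proxy address
)

# "custom_engine/model-123" is not listed explicitly anywhere; it is admitted
# because it matches the custom_engine/* entry in model_list.
response = client.chat.completions.create(
    model="custom_engine/model-123",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)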
diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py
index 31daa358a..da883b7e6 100644
--- a/tests/proxy_unit_tests/test_user_api_key_auth.py
+++ b/tests/proxy_unit_tests/test_user_api_key_auth.py
@@ -15,6 +15,10 @@ from starlette.datastructures import URL

 import litellm
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.auth.auth_checks import (
+    _model_is_within_list_of_allowed_models,
+    _model_matches_patterns,
+)


 class Request:
@@ -387,3 +391,62 @@ def test_is_api_route_allowed(route, user_role, expected_result):
             pass
         else:
             raise e
+
+
+@pytest.mark.parametrize(
+    "model, allowed_models, expected_result",
+    [
+        # Empty allowed models list
+        ("gpt-4", [], True),
+        # Universal access patterns
+        ("gpt-4", ["*"], True),
+        ("gpt-4", ["all-proxy-models"], True),
+        # Exact matches
+        ("gpt-4", ["gpt-4"], True),
+        ("gpt-4", ["gpt-3.5", "claude-2"], False),
+        # Pattern matching
+        ("azure/gpt-4", ["azure/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
+        ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
+        ("openai/gpt-4o", ["openai/*"], True),
+        ("openai/gpt-12", ["openai/*"], True),
+        ("gpt-4", ["gpt-*"], True),
+        ("gpt-4", ["claude-*"], False),
+        # Mixed scenarios
+        ("gpt-4", ["claude-2", "gpt-*", "palm-2"], True),
+        ("anthropic/claude-instant-1", ["anthropic/*", "gpt-4"], True),
+    ],
+)
+def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_result):
+    result = _model_is_within_list_of_allowed_models(
+        model=model, allowed_models=allowed_models
+    )
+    assert result == expected_result
+
+
+@pytest.mark.parametrize(
+    "model, allowed_models, expected_result",
+    [
+        # Basic pattern matching
+        ("gpt-4", ["gpt-*"], True),
+        ("azure/gpt-4", ["azure/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*"], True),
+        ("openai/my-fake-model", ["openai/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
+        ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
+        # Multiple patterns
+        ("gpt-4", ["claude-*", "gpt-*"], True),
+        ("anthropic/claude-instant-1", ["anthropic/*", "gpt-*"], True),
+        # No matches
+        ("gpt-4", ["claude-*"], False),
+        ("azure/gpt-4", ["anthropic/*"], False),
+        # Edge cases
+        ("gpt-4", ["gpt-4"], False),  # No wildcard in the pattern, should return False
+        ("model", ["*-suffix"], False),
+        ("prefix-model", ["prefix-*"], True),
+    ],
+)
+def test_model_matches_patterns(model, allowed_models, expected_result):
+    result = _model_matches_patterns(model=model, allowed_models=allowed_models)
+    assert result == expected_result
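Reviewer note: the end-to-end tests below rely on /key/generate accepting wildcard patterns in its models field, via the generate_key helper in the test suite. A sketch of the underlying call, assuming a proxy at http://localhost:4000 with master key "sk-1234" (both placeholders):

import asyncio

import aiohttp


async def generate_wildcard_key():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:4000/key/generate",
            headers={"Authorization": "Bearer sk-1234"},
            json={"models": ["custom_engine/*"]},  # wildcard pattern, not an exact name
        ) as response:
            data = await response.json()
            print(data["key"])  # this key can only call models matching custom_engine/*
            return data


asyncio.run(generate_wildcard_key())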
diff --git a/tests/test_models.py b/tests/test_models.py
index 959fee016..2bf3fdcc6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -48,6 +48,8 @@ async def get_models(session, key):
         if status != 200:
             raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()


 @pytest.mark.asyncio
 async def test_get_models():
@@ -362,3 +364,60 @@ async def test_add_model_run_health():

         # cleanup
         await delete_model(session=session, model_id=model_id)
+
+
+@pytest.mark.asyncio
+async def test_wildcard_model_access():
+    """
+    Test key generation with a wildcard model access pattern (custom_engine/*)
+    - Generate a key with access to 'custom_engine/*'
+    - Call /models and /model/info to verify access
+    """
+    async with aiohttp.ClientSession() as session:
+        # Generate key with wildcard access
+        key_gen = await generate_key(session=session, models=["custom_engine/*"])
+        key = key_gen["key"]
+
+        # Get models list
+        print("\nTesting /models endpoint with wildcard key")
+        models_response = await get_models(session=session, key=key)
+
+        # verify /models response
+        _data = models_response["data"]
+        found_custom_engine_model = False
+        for model in _data:
+            if model["id"] == "custom_engine/*":
+                found_custom_engine_model = True
+                assert model["object"] == "model", "Incorrect object type"
+                assert model["owned_by"] == "openai", "Incorrect owner"
+                break
+        assert (
+            found_custom_engine_model is True
+        ), "custom_engine/* model not found in response"
+
+        # Get detailed model info
+        print("\nTesting /model/info endpoint with wildcard key")
+        model_info_response = await get_model_info(session=session, key=key)
+        print("Model info response:", model_info_response)
+
+        # Add assertions to verify response content
+        assert "data" in model_info_response
+        model_data = model_info_response["data"]
+        assert len(model_data) > 0
+
+        # Find and verify the custom_engine/* model
+        custom_engine_model = None
+        for model in model_data:
+            if model["model_name"] == "custom_engine/*":
+                custom_engine_model = model
+                break
+
+        assert (
+            custom_engine_model is not None
+        ), "custom_engine/* model not found in response"
+        assert (
+            custom_engine_model["litellm_params"]["api_base"]
+            == "https://exampleopenaiendpoint-production.up.railway.app/"
+        )
+        assert custom_engine_model["litellm_params"]["model"] == "openai/custom_engine"
+        assert custom_engine_model["model_info"]["litellm_provider"] == "openai"
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 4dbeda188..3d1296b96 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -473,6 +473,24 @@ async def test_openai_wildcard_chat_completion():
         await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")


+@pytest.mark.asyncio
+async def test_regex_pattern_matching_e2e_test():
+    """
+    - Create a key with models=["custom-llm-engine/*"] -> the key has access to every model matching the pattern "custom-llm-engine/*"
+    - proxy_server_config.yaml routes "custom-llm-engine/*"
+    - Make a chat completion call with a model that matches the pattern
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, models=["custom-llm-engine/*"])
+        key = key_gen["key"]
+
+        # call chat/completions with a model that matches the key's wildcard
+        # pattern but is not listed explicitly on the config.yaml
+        await chat_completion(
+            session=session, key=key, model="custom-llm-engine/very-new-model"
+        )
+
+
 @pytest.mark.asyncio
 async def test_proxy_all_models():
     """
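Reviewer note: taken together, the checks in _model_is_within_list_of_allowed_models apply in this order: empty list, "*", "all-proxy-models", the legacy "openai/*" alias, exact match, then wildcard pattern match. A sketch mirroring the parametrized unit tests above, assuming the litellm proxy package is importable locally:

from litellm.proxy.auth.auth_checks import _model_is_within_list_of_allowed_models

assert _model_is_within_list_of_allowed_models(model="gpt-4", allowed_models=[])  # empty list -> allow
assert _model_is_within_list_of_allowed_models(model="gpt-4", allowed_models=["all-proxy-models"])  # universal alias
assert _model_is_within_list_of_allowed_models(model="azure/gpt-4", allowed_models=["azure/*"])  # wildcard match
assert not _model_is_within_list_of_allowed_models(model="gpt-4", allowed_models=["claude-2"])  # no match -> deny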