From 0e36333051d294dc9815c29e460a70116dc67da5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 13:50:33 -0800 Subject: [PATCH 01/10] add check for model_matches_patterns --- litellm/proxy/auth/auth_checks.py | 48 +++++++++++++++++++++++-------- litellm/proxy/proxy_config.yaml | 25 ++-------------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 7d29032c6..b1c27a409 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -9,6 +9,7 @@ Run checks for: 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ +import re import time import traceback from datetime import datetime @@ -34,6 +35,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics +from litellm.router_utils.pattern_match_deployments import PatternMatchRouter from litellm.types.services import ServiceLoggerPayload, ServiceTypes from .auth_checks_organization import organization_role_based_access_check @@ -48,8 +50,8 @@ else: last_db_access_time = LimitedSizeOrderedDict(max_size=100) db_cache_expiry = 5 # refresh every 5s - all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value +pattern_router = PatternMatchRouter() def common_checks( # noqa: PLR0915 @@ -828,7 +830,7 @@ async def can_key_call_model( model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth ) -> Literal[True]: """ - Checks if token can call a given model + Checks if token can call a given model, supporting regex/wildcard patterns Returns: - True: if token allowed to call model @@ -863,20 +865,18 @@ async def can_key_call_model( # Filter out models that are access_groups filtered_models = [m for m in valid_token.models if m not in access_groups] - filtered_models += models_in_current_access_groups verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") - all_model_access: bool = False + # Check for universal access patterns + if len(filtered_models) == 0: + return True + if "*" in filtered_models: + return True + if model_matches_patterns(model=model, allowed_models=filtered_models) is True: + return True - if ( - len(filtered_models) == 0 - or "*" in filtered_models - or "openai/*" in filtered_models - ): - all_model_access = True - - if model is not None and model not in filtered_models and all_model_access is False: + if model is not None and model not in filtered_models: raise ValueError( f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" ) @@ -885,3 +885,27 @@ async def can_key_call_model( f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" ) return True + + +def model_matches_patterns(model: str, allowed_models: List[str]) -> bool: + """ + Helper function to check if a model matches any of the allowed model patterns. + + Args: + model (str): The model to check (e.g., "custom_engine/model-123") + allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4*"]) + + Returns: + bool: True if model matches any allowed pattern, False otherwise + """ + try: + # Create pattern router instance + for _model in allowed_models: + if "*" in _model: + regex_pattern = pattern_router._pattern_to_regex(_model) + if re.match(regex_pattern, model): + return True + return False + except Exception as e: + verbose_proxy_logger.exception(f"Error in model_matches_patterns: {str(e)}") + return False diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 13fb1bcbe..596335e16 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,24 +1,5 @@ model_list: - - model_name: gpt-4o + - model_name: custom_engine/* litellm_params: - model: openai/gpt-4o - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - - model_name: fake-anthropic-endpoint - litellm_params: - model: anthropic/fake - api_base: https://exampleanthropicendpoint-production.up.railway.app/ - -router_settings: - provider_budget_config: - openai: - budget_limit: 0.3 # float of $ value budget for time period - time_period: 1d # can be 1d, 2d, 30d - anthropic: - budget_limit: 5 - time_period: 1d - redis_host: os.environ/REDIS_HOST - redis_port: os.environ/REDIS_PORT - redis_password: os.environ/REDIS_PASSWORD - -litellm_settings: - callbacks: ["prometheus"] \ No newline at end of file + model: openai/custom_engine + api_base: https://exampleopenaiendpoint-production.up.railway.app/ \ No newline at end of file From e9cdbff75a3721748f7223a55ddc498049fec5ed Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 13:58:08 -0800 Subject: [PATCH 02/10] add check for _model_is_within_list_of_allowed_models --- litellm/proxy/auth/auth_checks.py | 46 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b1c27a409..52f7f22fe 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -92,17 +92,15 @@ def common_checks( # noqa: PLR0915 ): # this means the team has access to all models on the proxy if ( - "all-proxy-models" in team_object.models - or "*" in team_object.models - or "openai/*" in team_object.models + _model_is_within_list_of_allowed_models( + model=_model, allowed_models=team_object.models + ) + is True ): - # this means the team has access to all models on the proxy pass # check if the team model is an access_group elif model_in_access_group(_model, team_object.models) is True: pass - elif _model and "*" in _model: - pass else: raise Exception( f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}" @@ -868,18 +866,16 @@ async def can_key_call_model( filtered_models += models_in_current_access_groups verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") - # Check for universal access patterns - if len(filtered_models) == 0: - return True - if "*" in filtered_models: - return True - if model_matches_patterns(model=model, allowed_models=filtered_models) is True: - return True - - if model is not None and model not in filtered_models: - raise ValueError( - f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" + if ( + _model_is_within_list_of_allowed_models( + model=model, allowed_models=filtered_models ) + is False + ): + raise ValueError( + f"API Key not allowed to access model. List of allowed models={filtered_models}. Tried to access {model}" + ) + valid_token.models = filtered_models verbose_proxy_logger.debug( f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" @@ -887,6 +883,22 @@ async def can_key_call_model( return True +def _model_is_within_list_of_allowed_models( + model: str, allowed_models: List[str] +) -> bool: + # Check for universal access patterns + if len(allowed_models) == 0: + return True + if "*" in allowed_models: + return True + if "all-proxy-models" in allowed_models: + return True + if model_matches_patterns(model=model, allowed_models=allowed_models) is True: + return True + + return False + + def model_matches_patterns(model: str, allowed_models: List[str]) -> bool: """ Helper function to check if a model matches any of the allowed model patterns. From ba4e1a7a0d9228a4624b1c06e241ae699bb34cac Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 14:02:31 -0800 Subject: [PATCH 03/10] fix model auth checks --- litellm/proxy/auth/auth_checks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 52f7f22fe..ac93315f8 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -893,6 +893,8 @@ def _model_is_within_list_of_allowed_models( return True if "all-proxy-models" in allowed_models: return True + if model in allowed_models: + return True if model_matches_patterns(model=model, allowed_models=allowed_models) is True: return True From 76dec9a0e8488b576a5abe32c8d90497de3b112c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 14:32:26 -0800 Subject: [PATCH 04/10] fix error msg --- litellm/proxy/auth/auth_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index ac93315f8..d5485fdc0 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -873,7 +873,7 @@ async def can_key_call_model( is False ): raise ValueError( - f"API Key not allowed to access model. List of allowed models={filtered_models}. Tried to access {model}" + f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" ) valid_token.models = filtered_models From 79d51e0bdd143385cdd27ca0fafba48f9732c565 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:31:56 -0800 Subject: [PATCH 05/10] add doc string --- litellm/proxy/auth/auth_checks.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d5485fdc0..b47cb19a5 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -886,6 +886,21 @@ async def can_key_call_model( def _model_is_within_list_of_allowed_models( model: str, allowed_models: List[str] ) -> bool: + """ + Checks if a model is within a list of allowed models (includes pattern matching checks) + + Args: + model (str): The model to check (e.g., "custom_engine/model-123") + allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4", "claude-sonnet"]) + + Returns: + bool: True if + - len(allowed_models) == 0 + - "*" in allowed_models (means all models are allowed) + - "all-proxy-models" in allowed_models (means all models are allowed) + - model is in allowed_models list + - model matches any allowed pattern + """ # Check for universal access patterns if len(allowed_models) == 0: return True @@ -895,13 +910,13 @@ def _model_is_within_list_of_allowed_models( return True if model in allowed_models: return True - if model_matches_patterns(model=model, allowed_models=allowed_models) is True: + if _model_matches_patterns(model=model, allowed_models=allowed_models) is True: return True return False -def model_matches_patterns(model: str, allowed_models: List[str]) -> bool: +def _model_matches_patterns(model: str, allowed_models: List[str]) -> bool: """ Helper function to check if a model matches any of the allowed model patterns. From f76d0e41ef513a755352b05cbe2777e48ce6f01e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:35:53 -0800 Subject: [PATCH 06/10] add unit testing for new model regex pattern matching --- .../test_user_api_key_auth.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 31daa358a..913ff34e1 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -15,6 +15,10 @@ from starlette.datastructures import URL import litellm from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.auth.auth_checks import ( + _model_is_within_list_of_allowed_models, + _model_matches_patterns, +) class Request: @@ -387,3 +391,59 @@ def test_is_api_route_allowed(route, user_role, expected_result): pass else: raise e + + +@pytest.mark.parametrize( + "model, allowed_models, expected_result", + [ + # Empty allowed models list + ("gpt-4", [], True), + # Universal access patterns + ("gpt-4", ["*"], True), + ("gpt-4", ["all-proxy-models"], True), + # Exact matches + ("gpt-4", ["gpt-4"], True), + ("gpt-4", ["gpt-3.5", "claude-2"], False), + # Pattern matching + ("azure/gpt-4", ["azure/*"], True), + ("custom_engine/model-123", ["custom_engine/*"], True), + ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), + ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), + ("gpt-4", ["gpt-*"], True), + ("gpt-4", ["claude-*"], False), + # Mixed scenarios + ("gpt-4", ["claude-2", "gpt-*", "palm-2"], True), + ("anthropic/claude-instant-1", ["anthropic/*", "gpt-4"], True), + ], +) +def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_result): + result = _model_is_within_list_of_allowed_models( + model=model, allowed_models=allowed_models + ) + assert result == expected_result + + +@pytest.mark.parametrize( + "model, allowed_models, expected_result", + [ + # Basic pattern matching + ("gpt-4", ["gpt-*"], True), + ("azure/gpt-4", ["azure/*"], True), + ("custom_engine/model-123", ["custom_engine/*"], True), + ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), + ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), + # Multiple patterns + ("gpt-4", ["claude-*", "gpt-*"], True), + ("anthropic/claude-instant-1", ["anthropic/*", "gpt-*"], True), + # No matches + ("gpt-4", ["claude-*"], False), + ("azure/gpt-4", ["anthropic/*"], False), + # Edge cases + ("gpt-4", ["gpt-4"], False), # No wildcard, should return False + ("model", ["*-suffix"], False), + ("prefix-model", ["prefix-*"], True), + ], +) +def test_model_matches_patterns(model, allowed_models, expected_result): + result = _model_matches_patterns(model=model, allowed_models=allowed_models) + assert result == expected_result From a18aeaa2fbeae51e463fa69ebe06813b8a78d9cc Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:39:42 -0800 Subject: [PATCH 07/10] add test_regex_pattern_matching_e2e_test --- proxy_server_config.yaml | 5 +++++ tests/test_openai_endpoints.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index b1d6b3dc6..4b1c21a7b 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -96,6 +96,11 @@ model_list: litellm_params: model: "groq/*" api_key: os.environ/GROQ_API_KEY + - model_name: "custom-llm-engine/*" + litellm_params: + model: "openai/my-fake-model" + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + api_key: fake-key - model_name: mistral-embed litellm_params: model: mistral/mistral-embed diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py index 4dbeda188..3d1296b96 100644 --- a/tests/test_openai_endpoints.py +++ b/tests/test_openai_endpoints.py @@ -473,6 +473,24 @@ async def test_openai_wildcard_chat_completion(): await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125") +@pytest.mark.asyncio +async def test_regex_pattern_matching_e2e_test(): + """ + - Create key for model = "custom-llm-engine/*" -> this has access to all models matching pattern = "custom-llm-engine/*" + - proxy_server_config.yaml has model = "custom-llm-engine/*" + - Make chat completion call + + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session, models=["custom-llm-engine/*"]) + key = key_gen["key"] + + # call chat/completions with a model that the key was not created for + the model is not on the config.yaml + await chat_completion( + session=session, key=key, model="custom-llm-engine/very-new-model" + ) + + @pytest.mark.asyncio async def test_proxy_all_models(): """ From 632283a371abcb4f5212e563fe3142af4074d7b7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:48:40 -0800 Subject: [PATCH 08/10] add coverage --- tests/proxy_unit_tests/test_user_api_key_auth.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 913ff34e1..da883b7e6 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -409,6 +409,8 @@ def test_is_api_route_allowed(route, user_role, expected_result): ("custom_engine/model-123", ["custom_engine/*"], True), ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), + ("openai/gpt-4o", ["openai/*"], True), + ("openai/gpt-12", ["openai/*"], True), ("gpt-4", ["gpt-*"], True), ("gpt-4", ["claude-*"], False), # Mixed scenarios @@ -430,6 +432,7 @@ def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_ ("gpt-4", ["gpt-*"], True), ("azure/gpt-4", ["azure/*"], True), ("custom_engine/model-123", ["custom_engine/*"], True), + ("openai/my-fake-model", ["openai/*"], True), ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), # Multiple patterns From a4491658e3e04fcf00e492b62ce6ae5065487dfe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:50:58 -0800 Subject: [PATCH 09/10] add "openai/*" in allowed models check --- litellm/proxy/auth/auth_checks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b47cb19a5..2f2450340 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -908,6 +908,9 @@ def _model_is_within_list_of_allowed_models( return True if "all-proxy-models" in allowed_models: return True + # Note: This is here to maintain backwards compatibility. This Used to be in our code and removing could impact existing customer code + if "openai/*" in allowed_models: + return True if model in allowed_models: return True if _model_matches_patterns(model=model, allowed_models=allowed_models) is True: From 6f34169f178157ef1df2b33ae4b11676cfec64fd Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 27 Nov 2024 18:38:36 -0800 Subject: [PATCH 10/10] add e2e tests for keys with regex patterns for /models and /model/info --- proxy_server_config.yaml | 4 +++ tests/test_models.py | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 4b1c21a7b..991a58de4 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -85,6 +85,10 @@ model_list: litellm_params: model: openai/* api_key: os.environ/OPENAI_API_KEY + - model_name: custom_engine/* # regex pattern matching test for /models and /model/info endpoints + litellm_params: + model: openai/custom_engine + api_base: https://exampleopenaiendpoint-production.up.railway.app/ # provider specific wildcard routing diff --git a/tests/test_models.py b/tests/test_models.py index 959fee016..2bf3fdcc6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -48,6 +48,8 @@ async def get_models(session, key): if status != 200: raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + @pytest.mark.asyncio async def test_get_models(): @@ -362,3 +364,60 @@ async def test_add_model_run_health(): # cleanup await delete_model(session=session, model_id=model_id) + + +@pytest.mark.asyncio +async def test_wildcard_model_access(): + """ + Test key generation with wildcard model access pattern (custom_llm/*) + - Generate key with access to 'custom_llm/*' + - Call /models and /model/info to verify access + """ + async with aiohttp.ClientSession() as session: + # Generate key with wildcard access + key_gen = await generate_key(session=session, models=["custom_engine/*"]) + key = key_gen["key"] + + # Get models list + print("\nTesting /models endpoint with wildcard key") + models_response = await get_models(session=session, key=key) + + # verify /models response + _data = models_response["data"] + found_custom_engine_model = False + for model in _data: + if model["id"] == "custom_engine/*": + found_custom_engine_model = True + assert model["object"] == "model", "Incorrect object type" + assert model["owned_by"] == "openai", "Incorrect owner" + break + assert ( + found_custom_engine_model is True + ), "custom_engine/* model not found in response" + + # Get detailed model info + print("\nTesting /model/info endpoint with wildcard key") + model_info_response = await get_model_info(session=session, key=key) + print("Model info response:", model_info_response) + + # Add assertions to verify response content + assert "data" in model_info_response + model_data = model_info_response["data"] + assert len(model_data) > 0 + + # Find and verify the custom_engine/* model + custom_engine_model = None + for model in model_data: + if model["model_name"] == "custom_engine/*": + custom_engine_model = model + break + + assert ( + custom_engine_model is not None + ), "custom_engine/* model not found in response" + assert ( + custom_engine_model["litellm_params"]["api_base"] + == "https://exampleopenaiendpoint-production.up.railway.app/" + ) + assert custom_engine_model["litellm_params"]["model"] == "openai/custom_engine" + assert custom_engine_model["model_info"]["litellm_provider"] == "openai"