diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 270dab7f21..5a456aec97 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1992,6 +1992,8 @@ class SpendCalculateRequest(LiteLLMPydanticObjectBase): class ProxyErrorTypes(str, enum.Enum): budget_exceeded = "budget_exceeded" + key_model_access_denied = "key_model_access_denied" + team_model_access_denied = "team_model_access_denied" expired_key = "expired_key" auth_error = "auth_error" internal_server_error = "internal_server_error" diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index c150534110..d6bbf760bd 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -14,12 +14,14 @@ import time import traceback from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from fastapi import status from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache from litellm.caching.dual_cache import LimitedSizeOrderedDict +from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, CallInfo, @@ -31,6 +33,8 @@ from litellm.proxy._types import ( LiteLLM_UserTable, LiteLLMRoutes, LitellmUserRoles, + ProxyErrorTypes, + ProxyException, UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks @@ -887,8 +891,11 @@ async def can_key_call_model( all_model_access = True if model is not None and model not in filtered_models and all_model_access is False: - raise ValueError( - f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" + raise ProxyException( + message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}", + type=ProxyErrorTypes.key_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, ) valid_token.models = filtered_models verbose_proxy_logger.debug( @@ -1064,11 +1071,7 @@ def _team_model_access_check( and model not in team_object.models ): # this means the team has access to all models on the proxy - if ( - "all-proxy-models" in team_object.models - or "*" in team_object.models - or "openai/*" in team_object.models - ): + if "all-proxy-models" in team_object.models or "*" in team_object.models: # this means the team has access to all models on the proxy pass # check if the team model is an access_group @@ -1086,8 +1089,11 @@ def _team_model_access_check( ): pass else: - raise Exception( - f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}" + raise ProxyException( + message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}", + type=ProxyErrorTypes.team_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, ) @@ -1121,10 +1127,51 @@ def _model_matches_any_wildcard_pattern_in_list( - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False """ - return any( - "*" in allowed_model_pattern + + if any( + _is_wildcard_pattern(allowed_model_pattern) and is_model_allowed_by_pattern( model=model, allowed_model_pattern=allowed_model_pattern ) for allowed_model_pattern in allowed_model_list + ): + return True + + if any( + _is_wildcard_pattern(allowed_model_pattern) + and _model_custom_llm_provider_matches_wildcard_pattern( + model=model, allowed_model_pattern=allowed_model_pattern + ) + for allowed_model_pattern in allowed_model_list + ): + return True + + return False + + +def _model_custom_llm_provider_matches_wildcard_pattern( + model: str, allowed_model_pattern: str +) -> bool: + """ + Returns True for this scenario: + - `model=gpt-4o` + - `allowed_model_pattern=openai/*` + + or + - `model=claude-3-5-sonnet-20240620` + - `allowed_model_pattern=anthropic/*` + """ + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + return is_model_allowed_by_pattern( + model=f"{custom_llm_provider}/{model}", + allowed_model_pattern=allowed_model_pattern, ) + + +def _is_wildcard_pattern(allowed_model_pattern: str) -> bool: + """ + Returns True if the pattern is a wildcard pattern. + + Checks if `*` is in the pattern. + """ + return "*" in allowed_model_pattern diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml index d3f1236159..3247516296 100644 --- a/litellm/proxy/example_config_yaml/otel_test_config.yaml +++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml @@ -31,7 +31,15 @@ model_list: api_key: fake-key model_info: supports_vision: True - + - model_name: bedrock/* + litellm_params: + model: bedrock/* + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: openai/* + litellm_params: + model: openai/* + api_key: os.environ/OPENAI_API_KEY + api_base: https://exampleopenaiendpoint-production.up.railway.app/ litellm_settings: diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 791dac4e8c..f1d27c35da 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -3,6 +3,10 @@ model_list: litellm_params: model: openai/* api_key: os.environ/OPENAI_API_KEY + - model_name: bedrock/* + litellm_params: + model: bedrock/* + api_base: https://exampleopenaiendpoint-production.up.railway.app/ - model_name: text-embedding-ada-002 litellm_params: model: openai/text-embedding-ada-002 @@ -14,6 +18,12 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app/ + + +litellm_settings: + callbacks: ["langfuse"] + + litellm_settings: callbacks: ["prometheus"] prometheus_initialize_budget_metrics: true diff --git a/tests/otel_tests/test_e2e_model_access.py b/tests/otel_tests/test_e2e_model_access.py new file mode 100644 index 0000000000..73c93212bf --- /dev/null +++ b/tests/otel_tests/test_e2e_model_access.py @@ -0,0 +1,290 @@ +import pytest +import asyncio +import aiohttp +import json +from httpx import AsyncClient +from typing import Any, Optional, List, Literal + + +async def generate_key( + session, models: Optional[List[str]] = None, team_id: Optional[str] = None +): + """Helper function to generate a key with specific model access""" + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {} + if models is not None: + data["models"] = models + if team_id is not None: + data["team_id"] = team_id + async with session.post(url, headers=headers, json=data) as response: + return await response.json() + + +async def generate_team(session, models: Optional[List[str]] = None): + """Helper function to generate a team with specific model access""" + url = "http://0.0.0.0:4000/team/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = {} + if models is not None: + data["models"] = models + async with session.post(url, headers=headers, json=data) as response: + return await response.json() + + +async def mock_chat_completion(session, key: str, model: str): + """Make a chat completion request using OpenAI SDK""" + from openai import AsyncOpenAI + import uuid + + client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1") + + response = await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": f"Say hello! {uuid.uuid4()}"}], + extra_body={ + "mock_response": "mock_response", + }, + ) + return response + + +@pytest.mark.parametrize( + "key_models, test_model, expect_success", + [ + (["openai/*"], "anthropic/claude-2", False), # Non-matching model + (["gpt-4"], "gpt-4", True), # Exact model match + (["bedrock/*"], "bedrock/anthropic.claude-3", True), # Bedrock wildcard + (["bedrock/anthropic.*"], "bedrock/anthropic.claude-3", True), # Pattern match + (["bedrock/anthropic.*"], "bedrock/amazon.titan", False), # Pattern non-match + (None, "gpt-4", True), # No model restrictions + ([], "gpt-4", True), # Empty model list + ], +) +@pytest.mark.asyncio +async def test_model_access_patterns(key_models, test_model, expect_success): + """ + Test model access patterns for API keys: + 1. Create key with specific model access pattern + 2. Attempt to make completion with test model + 3. Verify access is granted/denied as expected + """ + async with aiohttp.ClientSession() as session: + # Generate key with specified model access + key_gen = await generate_key(session=session, models=key_models) + key = key_gen["key"] + + try: + response = await mock_chat_completion( + session=session, + key=key, + model=test_model, + ) + if not expect_success: + pytest.fail(f"Expected request to fail for model {test_model}") + assert ( + response is not None + ), "Should get valid response when access is allowed" + except Exception as e: + if expect_success: + pytest.fail(f"Expected request to succeed but got error: {e}") + _error_body = e.body + + # Assert error structure and values + assert _error_body["type"] == "key_model_access_denied" + assert _error_body["param"] == "model" + assert _error_body["code"] == "401" + assert "API Key not allowed to access model" in _error_body["message"] + + +@pytest.mark.asyncio +async def test_model_access_update(): + """ + Test updating model access for an existing key: + 1. Create key with restricted model access + 2. Verify access patterns + 3. Update key with new model access + 4. Verify new access patterns + """ + client = AsyncClient(base_url="http://0.0.0.0:4000") + headers = {"Authorization": "Bearer sk-1234"} + + # Create initial key with restricted access + response = await client.post( + "/key/generate", json={"models": ["openai/gpt-4"]}, headers=headers + ) + assert response.status_code == 200 + key_data = response.json() + key = key_data["key"] + + # Test initial access + async with aiohttp.ClientSession() as session: + # Should work with gpt-4 + await mock_chat_completion(session=session, key=key, model="openai/gpt-4") + + # Should fail with gpt-3.5-turbo + with pytest.raises(Exception) as exc_info: + await mock_chat_completion( + session=session, key=key, model="openai/gpt-3.5-turbo" + ) + _validate_model_access_exception( + exc_info.value, expected_type="key_model_access_denied" + ) + + # Update key with new model access + response = await client.post( + "/key/update", json={"key": key, "models": ["openai/*"]}, headers=headers + ) + assert response.status_code == 200 + + # Test updated access + async with aiohttp.ClientSession() as session: + # Both models should now work + await mock_chat_completion(session=session, key=key, model="openai/gpt-4") + await mock_chat_completion( + session=session, key=key, model="openai/gpt-3.5-turbo" + ) + + # Non-OpenAI model should still fail + with pytest.raises(Exception) as exc_info: + await mock_chat_completion( + session=session, key=key, model="anthropic/claude-2" + ) + _validate_model_access_exception( + exc_info.value, expected_type="key_model_access_denied" + ) + + +@pytest.mark.parametrize( + "team_models, test_model, expect_success", + [ + (["openai/*"], "anthropic/claude-2", False), # Non-matching model + (["gpt-4"], "gpt-4", True), # Exact model match + (["bedrock/*"], "bedrock/anthropic.claude-3", True), # Bedrock wildcard + (["bedrock/anthropic.*"], "bedrock/anthropic.claude-3", True), # Pattern match + (["bedrock/anthropic.*"], "bedrock/amazon.titan", False), # Pattern non-match + (None, "gpt-4", True), # No model restrictions + ([], "gpt-4", True), # Empty model list + ], +) +@pytest.mark.asyncio +async def test_team_model_access_patterns(team_models, test_model, expect_success): + """ + Test model access patterns for team-based API keys: + 1. Create team with specific model access pattern + 2. Generate key for that team + 3. Attempt to make completion with test model + 4. Verify access is granted/denied as expected + """ + client = AsyncClient(base_url="http://0.0.0.0:4000") + headers = {"Authorization": "Bearer sk-1234"} + + async with aiohttp.ClientSession() as session: + try: + team_gen = await generate_team(session=session, models=team_models) + print("created team", team_gen) + team_id = team_gen["team_id"] + key_gen = await generate_key(session=session, team_id=team_id) + print("created key", key_gen) + key = key_gen["key"] + response = await mock_chat_completion( + session=session, + key=key, + model=test_model, + ) + if not expect_success: + pytest.fail(f"Expected request to fail for model {test_model}") + assert ( + response is not None + ), "Should get valid response when access is allowed" + except Exception as e: + if expect_success: + pytest.fail(f"Expected request to succeed but got error: {e}") + _validate_model_access_exception( + e, expected_type="team_model_access_denied" + ) + + +@pytest.mark.asyncio +async def test_team_model_access_update(): + """ + Test updating model access for a team: + 1. Create team with restricted model access + 2. Verify access patterns + 3. Update team with new model access + 4. Verify new access patterns + """ + client = AsyncClient(base_url="http://0.0.0.0:4000") + headers = {"Authorization": "Bearer sk-1234"} + + # Create initial team with restricted access + response = await client.post( + "/team/new", + json={"models": ["openai/gpt-4"], "name": "test-team"}, + headers=headers, + ) + assert response.status_code == 200 + team_data = response.json() + team_id = team_data["team_id"] + + # Generate a key for this team + response = await client.post( + "/key/generate", json={"team_id": team_id}, headers=headers + ) + assert response.status_code == 200 + key = response.json()["key"] + + # Test initial access + async with aiohttp.ClientSession() as session: + # Should work with gpt-4 + await mock_chat_completion(session=session, key=key, model="openai/gpt-4") + + # Should fail with gpt-3.5-turbo + with pytest.raises(Exception) as exc_info: + await mock_chat_completion( + session=session, key=key, model="openai/gpt-3.5-turbo" + ) + _validate_model_access_exception( + exc_info.value, expected_type="team_model_access_denied" + ) + + # Update team with new model access + response = await client.post( + "/team/update", + json={"team_id": team_id, "models": ["openai/*"]}, + headers=headers, + ) + assert response.status_code == 200 + + # Test updated access + async with aiohttp.ClientSession() as session: + # Both models should now work + await mock_chat_completion(session=session, key=key, model="openai/gpt-4") + await mock_chat_completion( + session=session, key=key, model="openai/gpt-3.5-turbo" + ) + + # Non-OpenAI model should still fail + with pytest.raises(Exception) as exc_info: + await mock_chat_completion( + session=session, key=key, model="anthropic/claude-2" + ) + _validate_model_access_exception( + exc_info.value, expected_type="team_model_access_denied" + ) + + +def _validate_model_access_exception( + e: Exception, + expected_type: Literal["key_model_access_denied", "team_model_access_denied"], +): + _error_body = e.body + + # Assert error structure and values + assert _error_body["type"] == expected_type + assert _error_body["param"] == "model" + assert _error_body["code"] == "401" + if expected_type == "key_model_access_denied": + assert "API Key not allowed to access model" in _error_body["message"] + elif expected_type == "team_model_access_denied": + assert "Team not allowed to access model" in _error_body["message"] diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index 7dbb9363d5..9afadafd89 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -361,11 +361,9 @@ def test_call_with_invalid_model(prisma_client): asyncio.run(test()) except Exception as e: - assert ( - e.message - == "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision" - ) - pass + assert isinstance(e, ProxyException) + assert e.type == ProxyErrorTypes.key_model_access_denied + assert e.param == "model" def test_call_with_valid_model(prisma_client): @@ -3200,10 +3198,9 @@ async def test_team_access_groups(prisma_client): pytest.fail(f"This should have failed!. IT's an invalid model") except Exception as e: print("got exception", e) - assert ( - "not allowed to call model" in e.message - and "Allowed team models" in e.message - ) + assert isinstance(e, ProxyException) + assert e.type == ProxyErrorTypes.team_model_access_denied + assert e.param == "model" @pytest.mark.asyncio()