From 475c1d0f99ebed6ee3a4692e30e2fb655a00ee3b Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 27 Feb 2025 21:24:58 -0800 Subject: [PATCH] (security fix) - Enforce model access restrictions on Azure OpenAI route (#8888) * fix(user_api_key_auth.py): Fixes https://github.com/BerriAI/litellm/issues/8780 security fix - enforce model access checks on azure routes * test(test_user_api_key_auth.py): add unit testing * test(test_openai_endpoints.py): add e2e test to ensure azure routes also run through model validation checks --- litellm/proxy/auth/user_api_key_auth.py | 18 +++++++++++- .../test_user_api_key_auth.py | 17 +++++++++++ tests/test_openai_endpoints.py | 29 ++++++++++++++----- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 5b5cb038e0..ecefc64d67 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -8,6 +8,7 @@ Returns a UserAPIKeyAuth object if the API key is valid """ import asyncio +import re import secrets from datetime import datetime, timezone from typing import Optional, cast @@ -279,6 +280,21 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str: return LitellmUserRoles.TEAM +def get_model_from_request(request_data: dict, route: str) -> Optional[str]: + + # First try to get model from request_data + model = request_data.get("model") + + # If model not in request_data, try to extract from route + if model is None: + # Parse model from route that follows the pattern /openai/deployments/{model}/* + match = re.match(r"/openai/deployments/([^/]+)", route) + if match: + model = match.group(1) + + return model + + async def _user_api_key_auth_builder( # noqa: PLR0915 request: Request, api_key: str, @@ -807,7 +823,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # the validation will occur when checking the team has access to this model pass else: - model = request_data.get("model", None) + model = get_model_from_request(request_data, route) fallback_models = cast( Optional[List[ALL_FALLBACK_MODEL_VALUES]], request_data.get("fallbacks", None), diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 141a7cbaad..dbe49a560d 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -930,3 +930,20 @@ def test_can_rbac_role_call_model_no_role_permissions(): general_settings={"role_permissions": []}, model="anthropic-claude", ) + + +@pytest.mark.parametrize( + "route, request_data, expected_model", + [ + ("/v1/chat/completions", {"model": "gpt-4"}, "gpt-4"), + ("/v1/completions", {"model": "gpt-4"}, "gpt-4"), + ("/v1/chat/completions", {}, None), + ("/v1/completions", {}, None), + ("/openai/deployments/gpt-4", {}, "gpt-4"), + ("/openai/deployments/gpt-4", {"model": "gpt-4o"}, "gpt-4o"), + ], +) +def test_get_model_from_request(route, request_data, expected_model): + from litellm.proxy.auth.user_api_key_auth import get_model_from_request + + assert get_model_from_request(request_data, route) == expected_model diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py index 0faae9d333..45fd29721f 100644 --- a/tests/test_openai_endpoints.py +++ b/tests/test_openai_endpoints.py @@ -3,7 +3,7 @@ import pytest import asyncio import aiohttp, openai -from openai import OpenAI, AsyncOpenAI +from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI from typing import Optional, List, Union import uuid @@ -201,6 +201,14 @@ async def chat_completion_with_headers(session, key, model="gpt-4"): return raw_headers_json +async def chat_completion_with_model_from_route(session, key, route): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async def completion(session, key): url = "http://0.0.0.0:4000/completions" headers = { @@ -288,12 +296,19 @@ async def test_chat_completion(): make chat completion call """ async with aiohttp.ClientSession() as session: - key_gen = await generate_key(session=session) - key = key_gen["key"] - await chat_completion(session=session, key=key) - key_gen = await new_user(session=session) - key_2 = key_gen["key"] - await chat_completion(session=session, key=key_2) + key_gen = await generate_key(session=session, models=["gpt-3.5-turbo"]) + azure_client = AsyncAzureOpenAI( + azure_endpoint="http://0.0.0.0:4000", + azure_deployment="random-model", + api_key=key_gen["key"], + api_version="2024-02-15-preview", + ) + with pytest.raises(openai.AuthenticationError) as e: + response = await azure_client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}], + ) + assert "API Key not allowed to access model." in str(e) @pytest.mark.asyncio