(security fix) - Enforce model access restrictions on Azure OpenAI route (#8888)

* fix(user_api_key_auth.py): Fixes https://github.com/BerriAI/litellm/issues/8780

security fix - enforce model access checks on azure routes

* test(test_user_api_key_auth.py): add unit testing

* test(test_openai_endpoints.py): add e2e test to ensure azure routes also run through model validation checks
This commit is contained in:
Krish Dholakia 2025-02-27 21:24:58 -08:00 committed by GitHub
parent 2a3b70f2b6
commit 740bd7e9ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 8 deletions

View file

@@ -8,6 +8,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
 """
 import asyncio
+import re
 import secrets
 from datetime import datetime, timezone
 from typing import Optional, cast
@@ -279,6 +280,21 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
         return LitellmUserRoles.TEAM
def get_model_from_request(request_data: dict, route: str) -> Optional[str]:
    """Resolve which model an incoming request is targeting.

    The ``model`` field in the request body takes precedence. When it is
    absent, fall back to Azure OpenAI-style routes, where the deployment
    (model) name is embedded in the path as
    ``/openai/deployments/{model}/...``.

    Returns the resolved model name, or ``None`` when neither source
    provides one.
    """
    requested = request_data.get("model")
    if requested is not None:
        return requested
    # Azure clients put the deployment name in the URL rather than the body;
    # without this fallback those requests would bypass model access checks.
    deployment = re.match(r"/openai/deployments/([^/]+)", route)
    return deployment.group(1) if deployment else None
async def _user_api_key_auth_builder(  # noqa: PLR0915
    request: Request,
    api_key: str,
@@ -807,7 +823,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                 # the validation will occur when checking the team has access to this model
                 pass
             else:
-                model = request_data.get("model", None)
+                model = get_model_from_request(request_data, route)
                 fallback_models = cast(
                     Optional[List[ALL_FALLBACK_MODEL_VALUES]],
                     request_data.get("fallbacks", None),

View file

@@ -930,3 +930,20 @@ def test_can_rbac_role_call_model_no_role_permissions():
        general_settings={"role_permissions": []},
        model="anthropic-claude",
    )
@pytest.mark.parametrize(
    "route, request_data, expected_model",
    [
        # Standard OpenAI routes: the model comes from the request body.
        ("/v1/chat/completions", {"model": "gpt-4"}, "gpt-4"),
        ("/v1/completions", {"model": "gpt-4"}, "gpt-4"),
        # No model in the body and no deployment in the route -> None.
        ("/v1/chat/completions", {}, None),
        ("/v1/completions", {}, None),
        # Azure-style route: the deployment path segment supplies the model.
        ("/openai/deployments/gpt-4", {}, "gpt-4"),
        # A model in the request body takes precedence over the route.
        ("/openai/deployments/gpt-4", {"model": "gpt-4o"}, "gpt-4o"),
    ],
)
def test_get_model_from_request(route, request_data, expected_model):
    """Model resolution prefers the body's `model`, then the Azure
    deployment segment of the route, and yields None otherwise."""
    from litellm.proxy.auth.user_api_key_auth import get_model_from_request

    resolved = get_model_from_request(request_data, route)
    assert resolved == expected_model

View file

@@ -3,7 +3,7 @@
 import pytest
 import asyncio
 import aiohttp, openai
-from openai import OpenAI, AsyncOpenAI
+from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
 from typing import Optional, List, Union
 import uuid
@@ -201,6 +201,14 @@ async def chat_completion_with_headers(session, key, model="gpt-4"):
    return raw_headers_json
async def chat_completion_with_model_from_route(session, key, route):
    # Helper for hitting the proxy's chat-completions endpoint with a given
    # API key; used to exercise model extraction from the request route.
    # NOTE(review): only the start of this helper is visible in this diff
    # hunk — the part that actually sends the request is truncated here.
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
async def completion(session, key):
    url = "http://0.0.0.0:4000/completions"
    headers = {
@@ -288,12 +296,19 @@ async def test_chat_completion():
    make chat completion call
    """
    async with aiohttp.ClientSession() as session:
-        key_gen = await generate_key(session=session)
-        key = key_gen["key"]
-        await chat_completion(session=session, key=key)
-        key_gen = await new_user(session=session)
-        key_2 = key_gen["key"]
-        await chat_completion(session=session, key=key_2)
+        key_gen = await generate_key(session=session, models=["gpt-3.5-turbo"])
+        azure_client = AsyncAzureOpenAI(
+            azure_endpoint="http://0.0.0.0:4000",
+            azure_deployment="random-model",
+            api_key=key_gen["key"],
+            api_version="2024-02-15-preview",
+        )
+        with pytest.raises(openai.AuthenticationError) as e:
+            response = await azure_client.chat.completions.create(
+                model="gpt-4",
+                messages=[{"role": "user", "content": "Hello!"}],
+            )
+        assert "API Key not allowed to access model." in str(e)

 @pytest.mark.asyncio