mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(security fix) - Enforce model access restrictions on Azure OpenAI route (#8888)
* fix(user_api_key_auth.py): enforce model access checks on Azure routes (security fix for https://github.com/BerriAI/litellm/issues/8780).
* test(test_user_api_key_auth.py): add unit tests.
* test(test_openai_endpoints.py): add an e2e test ensuring Azure routes also run through model validation checks.
This commit is contained in:
parent
2a3b70f2b6
commit
740bd7e9ce
3 changed files with 56 additions and 8 deletions
|
@ -8,6 +8,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, cast
|
||||
|
@ -279,6 +280,21 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
|||
return LitellmUserRoles.TEAM
|
||||
|
||||
|
||||
def get_model_from_request(request_data: dict, route: str) -> Optional[str]:
    """Resolve the requested model from the body, falling back to the route.

    Azure-style routes embed the deployment name in the path
    (``/openai/deployments/{model}/...``); when the request body carries no
    ``model`` key, the deployment segment is used instead.

    Args:
        request_data: Parsed request body (may be empty).
        route: The request path, e.g. ``/openai/deployments/gpt-4/chat/completions``.

    Returns:
        The model name, or ``None`` if neither the body nor the route names one.
    """
    # The request body takes precedence over anything encoded in the path.
    model = request_data.get("model")
    if model is not None:
        return model

    # Fall back to an Azure deployment route: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match is not None:
        return deployment_match.group(1)

    return None
|
||||
|
||||
|
||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str,
|
||||
|
@ -807,7 +823,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
# the validation will occur when checking the team has access to this model
|
||||
pass
|
||||
else:
|
||||
model = request_data.get("model", None)
|
||||
model = get_model_from_request(request_data, route)
|
||||
fallback_models = cast(
|
||||
Optional[List[ALL_FALLBACK_MODEL_VALUES]],
|
||||
request_data.get("fallbacks", None),
|
||||
|
|
|
@ -930,3 +930,20 @@ def test_can_rbac_role_call_model_no_role_permissions():
|
|||
general_settings={"role_permissions": []},
|
||||
model="anthropic-claude",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "route, request_data, expected_model",
    [
        # Standard OpenAI routes: the model comes from the request body.
        ("/v1/chat/completions", {"model": "gpt-4"}, "gpt-4"),
        ("/v1/completions", {"model": "gpt-4"}, "gpt-4"),
        # No model in the body and no deployment in the route: expect None.
        ("/v1/chat/completions", {}, None),
        ("/v1/completions", {}, None),
        # Azure-style route: the deployment path segment is used as the model.
        ("/openai/deployments/gpt-4", {}, "gpt-4"),
        # The body's model takes precedence over the route's deployment name.
        ("/openai/deployments/gpt-4", {"model": "gpt-4o"}, "gpt-4o"),
    ],
)
def test_get_model_from_request(route, request_data, expected_model):
    # Unit test for the security fix: model extraction must also cover
    # Azure /openai/deployments/{model} routes so access checks apply there.
    from litellm.proxy.auth.user_api_key_auth import get_model_from_request

    assert get_model_from_request(request_data, route) == expected_model
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import aiohttp, openai
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
|
||||
from typing import Optional, List, Union
|
||||
import uuid
|
||||
|
||||
|
@ -201,6 +201,14 @@ async def chat_completion_with_headers(session, key, model="gpt-4"):
|
|||
return raw_headers_json
|
||||
|
||||
|
||||
async def chat_completion_with_model_from_route(session, key, route):
    # Test helper for exercising routes where the model is taken from the
    # URL path (Azure-style) rather than the request body.
    # NOTE(review): as visible in this diff the helper only builds the URL
    # and auth headers and never issues a request — the body appears
    # truncated by the diff view; confirm the full implementation upstream.
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
|
||||
|
||||
|
||||
async def completion(session, key):
|
||||
url = "http://0.0.0.0:4000/completions"
|
||||
headers = {
|
||||
|
@ -288,12 +296,19 @@ async def test_chat_completion():
|
|||
make chat completion call
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
key_gen = await generate_key(session=session)
|
||||
key = key_gen["key"]
|
||||
await chat_completion(session=session, key=key)
|
||||
key_gen = await new_user(session=session)
|
||||
key_2 = key_gen["key"]
|
||||
await chat_completion(session=session, key=key_2)
|
||||
key_gen = await generate_key(session=session, models=["gpt-3.5-turbo"])
|
||||
azure_client = AsyncAzureOpenAI(
|
||||
azure_endpoint="http://0.0.0.0:4000",
|
||||
azure_deployment="random-model",
|
||||
api_key=key_gen["key"],
|
||||
api_version="2024-02-15-preview",
|
||||
)
|
||||
with pytest.raises(openai.AuthenticationError) as e:
|
||||
response = await azure_client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
assert "API Key not allowed to access model." in str(e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue