mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(security fix) - Enforce model access restrictions on Azure OpenAI route (#8888)
* fix(user_api_key_auth.py): Fixes https://github.com/BerriAI/litellm/issues/8780 security fix - enforce model access checks on azure routes * test(test_user_api_key_auth.py): add unit testing * test(test_openai_endpoints.py): add e2e test to ensure azure routes also run through model validation checks
This commit is contained in:
parent
2a3b70f2b6
commit
740bd7e9ce
3 changed files with 56 additions and 8 deletions
|
@ -8,6 +8,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
@ -279,6 +280,21 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
|
||||||
return LitellmUserRoles.TEAM
|
return LitellmUserRoles.TEAM
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_from_request(request_data: dict, route: str) -> Optional[str]:
    """Resolve the model name targeted by an incoming proxy request.

    The request body's ``model`` field wins when present. Otherwise, for
    Azure OpenAI-style routes of the form ``/openai/deployments/{model}/*``,
    the deployment segment of the path is used. Returns ``None`` when no
    model can be determined from either source.
    """
    requested = request_data.get("model")
    if requested is not None:
        return requested

    # Azure OpenAI routes embed the deployment (model) name in the URL path;
    # fall back to extracting it so access checks can still be enforced.
    deployment = re.match(r"/openai/deployments/([^/]+)", route)
    return deployment.group(1) if deployment else None
|
||||||
|
|
||||||
|
|
||||||
async def _user_api_key_auth_builder( # noqa: PLR0915
|
async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
@ -807,7 +823,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
# the validation will occur when checking the team has access to this model
|
# the validation will occur when checking the team has access to this model
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
model = request_data.get("model", None)
|
model = get_model_from_request(request_data, route)
|
||||||
fallback_models = cast(
|
fallback_models = cast(
|
||||||
Optional[List[ALL_FALLBACK_MODEL_VALUES]],
|
Optional[List[ALL_FALLBACK_MODEL_VALUES]],
|
||||||
request_data.get("fallbacks", None),
|
request_data.get("fallbacks", None),
|
||||||
|
|
|
@ -930,3 +930,20 @@ def test_can_rbac_role_call_model_no_role_permissions():
|
||||||
general_settings={"role_permissions": []},
|
general_settings={"role_permissions": []},
|
||||||
model="anthropic-claude",
|
model="anthropic-claude",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "route, request_data, expected_model",
    [
        # Model supplied explicitly in the request body.
        ("/v1/chat/completions", {"model": "gpt-4"}, "gpt-4"),
        ("/v1/completions", {"model": "gpt-4"}, "gpt-4"),
        # No model in body and no deployment route: helper returns None.
        ("/v1/chat/completions", {}, None),
        ("/v1/completions", {}, None),
        # Azure-style route: model is parsed from /openai/deployments/{model}.
        ("/openai/deployments/gpt-4", {}, "gpt-4"),
        # Body model takes precedence over the route's deployment segment.
        ("/openai/deployments/gpt-4", {"model": "gpt-4o"}, "gpt-4o"),
    ],
)
def test_get_model_from_request(route, request_data, expected_model):
    """Verify get_model_from_request resolves the model from the request
    body first, then falls back to Azure-style deployment routes."""
    from litellm.proxy.auth.user_api_key_auth import get_model_from_request

    assert get_model_from_request(request_data, route) == expected_model
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp, openai
|
import aiohttp, openai
|
||||||
from openai import OpenAI, AsyncOpenAI
|
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
|
||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
@ -201,6 +201,14 @@ async def chat_completion_with_headers(session, key, model="gpt-4"):
|
||||||
return raw_headers_json
|
return raw_headers_json
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_with_model_from_route(session, key, route):
|
||||||
|
url = "http://0.0.0.0:4000/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def completion(session, key):
|
async def completion(session, key):
|
||||||
url = "http://0.0.0.0:4000/completions"
|
url = "http://0.0.0.0:4000/completions"
|
||||||
headers = {
|
headers = {
|
||||||
|
@ -288,12 +296,19 @@ async def test_chat_completion():
|
||||||
make chat completion call
|
make chat completion call
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
key_gen = await generate_key(session=session)
|
key_gen = await generate_key(session=session, models=["gpt-3.5-turbo"])
|
||||||
key = key_gen["key"]
|
azure_client = AsyncAzureOpenAI(
|
||||||
await chat_completion(session=session, key=key)
|
azure_endpoint="http://0.0.0.0:4000",
|
||||||
key_gen = await new_user(session=session)
|
azure_deployment="random-model",
|
||||||
key_2 = key_gen["key"]
|
api_key=key_gen["key"],
|
||||||
await chat_completion(session=session, key=key_2)
|
api_version="2024-02-15-preview",
|
||||||
|
)
|
||||||
|
with pytest.raises(openai.AuthenticationError) as e:
|
||||||
|
response = await azure_client.chat.completions.create(
|
||||||
|
model="gpt-4",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}],
|
||||||
|
)
|
||||||
|
assert "API Key not allowed to access model." in str(e)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue