(proxy) - Auth fix, ensure re-using safe request body for checking model field (#7222)

* litellm fix auth check

* fix _read_request_body

* test_auth_with_form_data_and_model

* fix auth check

* fix _read_request_body

* fix _safe_get_request_headers
This commit is contained in:
Ishaan Jaff 2024-12-14 12:01:25 -08:00 committed by GitHub
parent ec36353b41
commit 9432812c90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 87 additions and 16 deletions

View file

@ -899,12 +899,10 @@ async def user_api_key_auth( # noqa: PLR0915
# the validation will occur when checking the team has access to this model # the validation will occur when checking the team has access to this model
pass pass
else: else:
try: model = request_data.get("model", None)
data = await request.json() fallback_models: Optional[List[str]] = request_data.get(
except json.JSONDecodeError: "fallbacks", None
data = {} # Provide a default value, such as an empty dictionary )
model = data.get("model", None)
fallback_models: Optional[List[str]] = data.get("fallbacks", None)
if model is not None: if model is not None:
await can_key_call_model( await can_key_call_model(

View file

@ -21,19 +21,23 @@ async def _read_request_body(request: Optional[Request]) -> Dict:
try: try:
if request is None: if request is None:
return {} return {}
_request_headers: dict = _safe_get_request_headers(request=request)
content_type = _request_headers.get("content-type", "")
if "form" in content_type:
return dict(await request.form())
else:
# Read the request body
body = await request.body()
# Read the request body # Return empty dict if body is empty or None
body = await request.body() if not body:
return {}
# Return empty dict if body is empty or None # Decode the body to a string
if not body: body_str = body.decode()
return {}
# Decode the body to a string # Attempt JSON parsing (safe for untrusted input)
body_str = body.decode() return json.loads(body_str)
# Attempt JSON parsing (safe for untrusted input)
return json.loads(body_str)
except json.JSONDecodeError: except json.JSONDecodeError:
# Log detailed information for debugging # Log detailed information for debugging
@ -48,6 +52,21 @@ async def _read_request_body(request: Optional[Request]) -> Dict:
return {} return {}
def _safe_get_request_headers(request: Optional[Request]) -> dict:
"""
[Non-Blocking] Safely get the request headers
"""
try:
if request is None:
return {}
return dict(request.headers)
except Exception as e:
verbose_proxy_logger.exception(
"Unexpected error reading request headers - {}".format(e)
)
return {}
def check_file_size_under_limit( def check_file_size_under_limit(
request_data: dict, request_data: dict,
file: UploadFile, file: UploadFile,

View file

@ -496,3 +496,57 @@ def test_read_request_body():
request.body = return_body request.body = return_body
result = _read_request_body(request) result = _read_request_body(request)
assert result is not None assert result is not None
@pytest.mark.asyncio
async def test_auth_with_form_data_and_model():
"""
Test user_api_key_auth when:
1. Request has form data instead of JSON body
2. Virtual key has a model set
"""
from fastapi import Request
from starlette.datastructures import URL, FormData
from litellm.proxy.proxy_server import (
hash_token,
user_api_key_cache,
user_api_key_auth,
)
# Setup
user_key = "sk-12345678"
# Create a virtual key with a specific model
valid_token = UserAPIKeyAuth(
token=hash_token(user_key),
models=["gpt-4"],
)
# Store the virtual key in cache
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
# Create request with form data
request = Request(
scope={
"type": "http",
"method": "POST",
"headers": [(b"content-type", b"application/x-www-form-urlencoded")],
}
)
request._url = URL(url="/chat/completions")
# Mock form data
form_data = FormData([("key1", "value1"), ("key2", "value2")])
async def return_form_data():
return form_data
request.form = return_form_data
# Test user_api_key_auth with form data request
response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
assert response.models == ["gpt-4"], "Model from virtual key should be preserved"