mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(proxy) - Auth fix, ensure re-using safe request body for checking model field (#7222)

* litellm fix auth check
* fix _read_request_body
* test_auth_with_form_data_and_model
* fix auth check
* fix _read_request_body
* fix _safe_get_request_headers
This commit is contained in:
parent ec36353b41
commit 9432812c90

3 changed files with 87 additions and 16 deletions
@@ -899,12 +899,10 @@ async def user_api_key_auth( # noqa: PLR0915
                 # the validation will occur when checking the team has access to this model
                 pass
             else:
-                try:
-                    data = await request.json()
-                except json.JSONDecodeError:
-                    data = {}  # Provide a default value, such as an empty dictionary
-                model = data.get("model", None)
-                fallback_models: Optional[List[str]] = data.get("fallbacks", None)
+                model = request_data.get("model", None)
+                fallback_models: Optional[List[str]] = request_data.get(
+                    "fallbacks", None
+                )

                 if model is not None:
                     await can_key_call_model(
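The hunk above is the heart of the auth fix: instead of re-parsing the raw body with request.json() (which raises json.JSONDecodeError on form-encoded posts, so the model field was silently dropped), the check reuses the request_data dict that _read_request_body already produced. A minimal sketch of that pattern, where _extract_model_fields is a hypothetical helper for illustration only, not part of the codebase:

# Hedged sketch: the auth path reuses the already-parsed request body instead of
# calling request.json() a second time, so form posts keep their "model" field.
from typing import List, Optional, Tuple


def _extract_model_fields(request_data: dict) -> Tuple[Optional[str], Optional[List[str]]]:
    """Pull the model and fallback models out of an already-parsed body."""
    model: Optional[str] = request_data.get("model", None)
    fallback_models: Optional[List[str]] = request_data.get("fallbacks", None)
    return model, fallback_models


# Works the same whether request_data came from a JSON or a form body,
# because parsing happened once, upstream in _read_request_body.
print(_extract_model_fields({"model": "gpt-4", "fallbacks": ["gpt-3.5-turbo"]}))

Because the body is parsed exactly once upstream, the model check now behaves identically for JSON and form submissions.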
@@ -21,7 +21,11 @@ async def _read_request_body(request: Optional[Request]) -> Dict:
     try:
         if request is None:
             return {}
+        _request_headers: dict = _safe_get_request_headers(request=request)
+        content_type = _request_headers.get("content-type", "")
+        if "form" in content_type:
+            return dict(await request.form())
+        else:
             # Read the request body
             body = await request.body()

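The hunk above makes _read_request_body content-type aware: form-encoded bodies are read via request.form(), everything else still goes through request.body(). Below is a small, hedged sketch of the same branch using plain Starlette; parse_body is an illustrative stand-in for the real helper, and it assumes starlette is installed (recent Starlette versions may also need python-multipart for form parsing):

import asyncio
import json
from typing import Dict, Optional

from starlette.requests import Request


async def parse_body(request: Optional[Request]) -> Dict:
    # Mirror of the new branch: forms are parsed with request.form(),
    # everything else is read as raw bytes and JSON-decoded.
    if request is None:
        return {}
    content_type = request.headers.get("content-type", "")
    if "form" in content_type:
        return dict(await request.form())
    body = await request.body()
    return json.loads(body) if body else {}


async def main() -> None:
    scope = {
        "type": "http",
        "method": "POST",
        "headers": [(b"content-type", b"application/x-www-form-urlencoded")],
    }

    async def receive():
        # Minimal ASGI receive channel yielding a single urlencoded body.
        return {"type": "http.request", "body": b"model=gpt-4&stream=true", "more_body": False}

    print(await parse_body(Request(scope, receive)))  # {'model': 'gpt-4', 'stream': 'true'}


asyncio.run(main())

Either way the caller gets back a plain dict, so downstream auth code does not need to care how the client encoded its request.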
@@ -48,6 +52,21 @@ async def _read_request_body(request: Optional[Request]) -> Dict:
         return {}


+def _safe_get_request_headers(request: Optional[Request]) -> dict:
+    """
+    [Non-Blocking] Safely get the request headers
+    """
+    try:
+        if request is None:
+            return {}
+        return dict(request.headers)
+    except Exception as e:
+        verbose_proxy_logger.exception(
+            "Unexpected error reading request headers - {}".format(e)
+        )
+        return {}
+
+
 def check_file_size_under_limit(
     request_data: dict,
     file: UploadFile,
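The new _safe_get_request_headers helper is deliberately defensive: it converts Starlette's Headers mapping to a plain dict and returns {} on any failure, so header inspection can never break request parsing. A hedged usage sketch, with safe_headers standing in for the real helper (ASGI servers deliver header names lowercased, hence the lowercase keys):

from typing import Optional

from starlette.requests import Request


def safe_headers(request: Optional[Request]) -> dict:
    # Stand-in for _safe_get_request_headers: never raise, just fall back to {}.
    try:
        if request is None:
            return {}
        return dict(request.headers)
    except Exception:
        return {}


scope = {
    "type": "http",
    "method": "GET",
    "headers": [(b"content-type", b"application/json"), (b"x-custom", b"1")],
}
print(safe_headers(Request(scope)))  # {'content-type': 'application/json', 'x-custom': '1'}
print(safe_headers(None))            # {}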
@@ -496,3 +496,57 @@ def test_read_request_body():
     request.body = return_body
     result = _read_request_body(request)
     assert result is not None
+
+
+@pytest.mark.asyncio
+async def test_auth_with_form_data_and_model():
+    """
+    Test user_api_key_auth when:
+    1. Request has form data instead of JSON body
+    2. Virtual key has a model set
+    """
+    from fastapi import Request
+    from starlette.datastructures import URL, FormData
+    from litellm.proxy.proxy_server import (
+        hash_token,
+        user_api_key_cache,
+        user_api_key_auth,
+    )
+
+    # Setup
+    user_key = "sk-12345678"
+
+    # Create a virtual key with a specific model
+    valid_token = UserAPIKeyAuth(
+        token=hash_token(user_key),
+        models=["gpt-4"],
+    )
+
+    # Store the virtual key in cache
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
+
+    # Create request with form data
+    request = Request(
+        scope={
+            "type": "http",
+            "method": "POST",
+            "headers": [(b"content-type", b"application/x-www-form-urlencoded")],
+        }
+    )
+    request._url = URL(url="/chat/completions")
+
+    # Mock form data
+    form_data = FormData([("key1", "value1"), ("key2", "value2")])
+
+    async def return_form_data():
+        return form_data
+
+    request.form = return_form_data
+
+    # Test user_api_key_auth with form data request
+    response = await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+
+    assert response.models == ["gpt-4"], "Model from virtual key should be preserved"
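For context on what the new test's mocking actually exercises, here is a hedged, standalone reproduction of the same trick: the request advertises a form content-type and request.form is swapped for an async stub, which is exactly what the form-aware body parsing will call. Only Starlette is used below; none of it is litellm API:

import asyncio

from starlette.datastructures import FormData
from starlette.requests import Request


async def main() -> None:
    request = Request(
        scope={
            "type": "http",
            "method": "POST",
            "headers": [(b"content-type", b"application/x-www-form-urlencoded")],
        }
    )

    # Same monkeypatch as the test: replace request.form with an async stub.
    form_data = FormData([("key1", "value1"), ("key2", "value2")])

    async def return_form_data() -> FormData:
        return form_data

    request.form = return_form_data  # type: ignore[method-assign]

    # This is the branch _read_request_body now takes for form content types.
    if "form" in dict(request.headers).get("content-type", ""):
        print(dict(await request.form()))  # {'key1': 'value1', 'key2': 'value2'}


asyncio.run(main())

The real test additionally asserts that user_api_key_auth preserves the virtual key's models list on that code path.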