mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Fix: Potential SQLi in spend_management_endpoints.py (#9878)
* fix: Potential SQLi in spend_management_endpoints.py * fix tests * test: add tests for global spend keys endpoint * chore: update error message * chore: lint * chore: rename test
This commit is contained in:
parent
10257426a2
commit
03245c732a
2 changed files with 81 additions and 9 deletions
|
@ -1919,9 +1919,7 @@ async def view_spend_logs( # noqa: PLR0915
|
|||
):
|
||||
result: dict = {}
|
||||
for record in response:
|
||||
dt_object = datetime.strptime(
|
||||
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
|
||||
) # type: ignore
|
||||
dt_object = datetime.strptime(str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ") # type: ignore
|
||||
date = dt_object.date()
|
||||
if date not in result:
|
||||
result[date] = {"users": {}, "models": {}}
|
||||
|
@ -2097,8 +2095,7 @@ async def global_spend_refresh():
|
|||
try:
|
||||
resp = await prisma_client.db.query_raw(sql_query)
|
||||
|
||||
assert resp[0]["relkind"] == "m"
|
||||
return True
|
||||
return resp[0]["relkind"] == "m"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
@ -2396,9 +2393,21 @@ async def global_spend_keys(
|
|||
return response
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};"""
|
||||
sql_query = """SELECT * FROM "Last30dKeysBySpend";"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
if limit is None:
|
||||
response = await prisma_client.db.query_raw(sql_query)
|
||||
return response
|
||||
try:
|
||||
limit = int(limit)
|
||||
if limit < 1:
|
||||
raise ValueError("Limit must be greater than 0")
|
||||
sql_query = """SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;"""
|
||||
response = await prisma_client.db.query_raw(sql_query, limit)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=422, detail={"error": f"Invalid limit: {limit}, error: {e}"}
|
||||
) from e
|
||||
|
||||
return response
|
||||
|
||||
|
@ -2646,9 +2655,9 @@ async def global_spend_models(
|
|||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
|
||||
sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};"""
|
||||
sql_query = """SELECT * FROM "Last30dModelsBySpend" LIMIT $1 ;"""
|
||||
|
||||
response = await prisma_client.db.query_raw(query=sql_query)
|
||||
response = await prisma_client.db.query_raw(sql_query, int(limit))
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
@ -731,3 +731,66 @@ def _compare_nested_dicts(
|
|||
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
|
||||
)
|
||||
return differences
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_spend_keys_endpoint_limit_validation(client, monkeypatch):
|
||||
"""
|
||||
Test to ensure that the global_spend_keys endpoint is protected against SQL injection attacks.
|
||||
Verifies that the limit parameter is properly parameterized and not directly interpolated.
|
||||
"""
|
||||
# Create a simple mock for prisma client with empty response
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_db = MagicMock()
|
||||
mock_query_raw = MagicMock()
|
||||
mock_query_raw.return_value = asyncio.Future()
|
||||
mock_query_raw.return_value.set_result([])
|
||||
mock_db.query_raw = mock_query_raw
|
||||
mock_prisma_client.db = mock_db
|
||||
# Apply the mock to the prisma_client module
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
|
||||
# Call the endpoint without specifying a limit
|
||||
no_limit_response = client.get("/global/spend/keys")
|
||||
assert no_limit_response.status_code == 200
|
||||
mock_query_raw.assert_called_once_with('SELECT * FROM "Last30dKeysBySpend";')
|
||||
# Reset the mock for the next test
|
||||
mock_query_raw.reset_mock()
|
||||
# Test with valid input
|
||||
normal_limit = "10"
|
||||
good_input_response = client.get(f"/global/spend/keys?limit={normal_limit}")
|
||||
assert good_input_response.status_code == 200
|
||||
# Verify the mock was called with the correct parameters
|
||||
mock_query_raw.assert_called_once_with(
|
||||
'SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;', 10
|
||||
)
|
||||
# Reset the mock for the next test
|
||||
mock_query_raw.reset_mock()
|
||||
# Test with SQL injection payload
|
||||
sql_injection_limit = "10; DROP TABLE spend_logs; --"
|
||||
response = client.get(f"/global/spend/keys?limit={sql_injection_limit}")
|
||||
# Verify the response is a validation error (422)
|
||||
assert response.status_code == 422
|
||||
# Verify the mock was not called with the SQL injection payload
|
||||
# This confirms that the validation happens before the database query
|
||||
mock_query_raw.assert_not_called()
|
||||
# Reset the mock for the next test
|
||||
mock_query_raw.reset_mock()
|
||||
# Test with non-numeric input
|
||||
non_numeric_limit = "abc"
|
||||
response = client.get(f"/global/spend/keys?limit={non_numeric_limit}")
|
||||
assert response.status_code == 422
|
||||
mock_query_raw.assert_not_called()
|
||||
mock_query_raw.reset_mock()
|
||||
# Test with negative number
|
||||
negative_limit = "-5"
|
||||
response = client.get(f"/global/spend/keys?limit={negative_limit}")
|
||||
assert response.status_code == 422
|
||||
mock_query_raw.assert_not_called()
|
||||
mock_query_raw.reset_mock()
|
||||
# Test with zero
|
||||
zero_limit = "0"
|
||||
response = client.get(f"/global/spend/keys?limit={zero_limit}")
|
||||
assert response.status_code == 422
|
||||
mock_query_raw.assert_not_called()
|
||||
mock_query_raw.reset_mock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue