Fix: Potential SQLi in spend_management_endpoints.py (#9878)

* fix: Potential SQLi in spend_management_endpoints.py

* fix tests

* test: add tests for global spend keys endpoint

* chore: update error message

* chore: lint

* chore: rename test
This commit is contained in:
Nilanjan De 2025-04-22 01:29:38 +04:00 committed by GitHub
parent 10257426a2
commit 03245c732a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 81 additions and 9 deletions

View file

@ -1919,9 +1919,7 @@ async def view_spend_logs( # noqa: PLR0915
):
result: dict = {}
for record in response:
dt_object = datetime.strptime(
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore
dt_object = datetime.strptime(str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ") # type: ignore
date = dt_object.date()
if date not in result:
result[date] = {"users": {}, "models": {}}
@ -2097,8 +2095,7 @@ async def global_spend_refresh():
try:
resp = await prisma_client.db.query_raw(sql_query)
assert resp[0]["relkind"] == "m"
return True
return resp[0]["relkind"] == "m"
except Exception:
return False
@ -2396,9 +2393,21 @@ async def global_spend_keys(
return response
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};"""
sql_query = """SELECT * FROM "Last30dKeysBySpend";"""
response = await prisma_client.db.query_raw(query=sql_query)
if limit is None:
response = await prisma_client.db.query_raw(sql_query)
return response
try:
limit = int(limit)
if limit < 1:
raise ValueError("Limit must be greater than 0")
sql_query = """SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;"""
response = await prisma_client.db.query_raw(sql_query, limit)
except ValueError as e:
raise HTTPException(
status_code=422, detail={"error": f"Invalid limit: {limit}, error: {e}"}
) from e
return response
@ -2646,9 +2655,9 @@ async def global_spend_models(
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};"""
sql_query = """SELECT * FROM "Last30dModelsBySpend" LIMIT $1 ;"""
response = await prisma_client.db.query_raw(query=sql_query)
response = await prisma_client.db.query_raw(sql_query, int(limit))
return response

View file

@ -731,3 +731,66 @@ def _compare_nested_dicts(
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
)
return differences
@pytest.mark.asyncio
async def test_global_spend_keys_endpoint_limit_validation(client, monkeypatch):
"""
Test to ensure that the global_spend_keys endpoint is protected against SQL injection attacks.
Verifies that the limit parameter is properly parameterized and not directly interpolated.
"""
# Create a simple mock for prisma client with empty response
mock_prisma_client = MagicMock()
mock_db = MagicMock()
mock_query_raw = MagicMock()
mock_query_raw.return_value = asyncio.Future()
mock_query_raw.return_value.set_result([])
mock_db.query_raw = mock_query_raw
mock_prisma_client.db = mock_db
# Apply the mock to the prisma_client module
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
# Call the endpoint without specifying a limit
no_limit_response = client.get("/global/spend/keys")
assert no_limit_response.status_code == 200
mock_query_raw.assert_called_once_with('SELECT * FROM "Last30dKeysBySpend";')
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with valid input
normal_limit = "10"
good_input_response = client.get(f"/global/spend/keys?limit={normal_limit}")
assert good_input_response.status_code == 200
# Verify the mock was called with the correct parameters
mock_query_raw.assert_called_once_with(
'SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;', 10
)
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with SQL injection payload
sql_injection_limit = "10; DROP TABLE spend_logs; --"
response = client.get(f"/global/spend/keys?limit={sql_injection_limit}")
# Verify the response is a validation error (422)
assert response.status_code == 422
# Verify the mock was not called with the SQL injection payload
# This confirms that the validation happens before the database query
mock_query_raw.assert_not_called()
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with non-numeric input
non_numeric_limit = "abc"
response = client.get(f"/global/spend/keys?limit={non_numeric_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()
# Test with negative number
negative_limit = "-5"
response = client.get(f"/global/spend/keys?limit={negative_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()
# Test with zero
zero_limit = "0"
response = client.get(f"/global/spend/keys?limit={zero_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()