mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Fix: Potential SQLi in spend_management_endpoints.py (#9878)
* fix: Potential SQLi in spend_management_endpoints.py * fix tests * test: add tests for global spend keys endpoint * chore: update error message * chore: lint * chore: rename test
This commit is contained in:
parent
a02f9f6f2c
commit
28467252c0
2 changed files with 81 additions and 9 deletions
|
@ -1919,9 +1919,7 @@ async def view_spend_logs( # noqa: PLR0915
|
||||||
):
|
):
|
||||||
result: dict = {}
|
result: dict = {}
|
||||||
for record in response:
|
for record in response:
|
||||||
dt_object = datetime.strptime(
|
dt_object = datetime.strptime(str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ") # type: ignore
|
||||||
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
|
|
||||||
) # type: ignore
|
|
||||||
date = dt_object.date()
|
date = dt_object.date()
|
||||||
if date not in result:
|
if date not in result:
|
||||||
result[date] = {"users": {}, "models": {}}
|
result[date] = {"users": {}, "models": {}}
|
||||||
|
@ -2097,8 +2095,7 @@ async def global_spend_refresh():
|
||||||
try:
|
try:
|
||||||
resp = await prisma_client.db.query_raw(sql_query)
|
resp = await prisma_client.db.query_raw(sql_query)
|
||||||
|
|
||||||
assert resp[0]["relkind"] == "m"
|
return resp[0]["relkind"] == "m"
|
||||||
return True
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -2396,9 +2393,21 @@ async def global_spend_keys(
|
||||||
return response
|
return response
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||||
sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};"""
|
sql_query = """SELECT * FROM "Last30dKeysBySpend";"""
|
||||||
|
|
||||||
response = await prisma_client.db.query_raw(query=sql_query)
|
if limit is None:
|
||||||
|
response = await prisma_client.db.query_raw(sql_query)
|
||||||
|
return response
|
||||||
|
try:
|
||||||
|
limit = int(limit)
|
||||||
|
if limit < 1:
|
||||||
|
raise ValueError("Limit must be greater than 0")
|
||||||
|
sql_query = """SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;"""
|
||||||
|
response = await prisma_client.db.query_raw(sql_query, limit)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422, detail={"error": f"Invalid limit: {limit}, error: {e}"}
|
||||||
|
) from e
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -2646,9 +2655,9 @@ async def global_spend_models(
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||||
|
|
||||||
sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};"""
|
sql_query = """SELECT * FROM "Last30dModelsBySpend" LIMIT $1 ;"""
|
||||||
|
|
||||||
response = await prisma_client.db.query_raw(query=sql_query)
|
response = await prisma_client.db.query_raw(sql_query, int(limit))
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
@ -731,3 +731,66 @@ def _compare_nested_dicts(
|
||||||
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
|
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
|
||||||
)
|
)
|
||||||
return differences
|
return differences
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_global_spend_keys_endpoint_limit_validation(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test to ensure that the global_spend_keys endpoint is protected against SQL injection attacks.
|
||||||
|
Verifies that the limit parameter is properly parameterized and not directly interpolated.
|
||||||
|
"""
|
||||||
|
# Create a simple mock for prisma client with empty response
|
||||||
|
mock_prisma_client = MagicMock()
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_query_raw = MagicMock()
|
||||||
|
mock_query_raw.return_value = asyncio.Future()
|
||||||
|
mock_query_raw.return_value.set_result([])
|
||||||
|
mock_db.query_raw = mock_query_raw
|
||||||
|
mock_prisma_client.db = mock_db
|
||||||
|
# Apply the mock to the prisma_client module
|
||||||
|
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||||
|
|
||||||
|
# Call the endpoint without specifying a limit
|
||||||
|
no_limit_response = client.get("/global/spend/keys")
|
||||||
|
assert no_limit_response.status_code == 200
|
||||||
|
mock_query_raw.assert_called_once_with('SELECT * FROM "Last30dKeysBySpend";')
|
||||||
|
# Reset the mock for the next test
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
# Test with valid input
|
||||||
|
normal_limit = "10"
|
||||||
|
good_input_response = client.get(f"/global/spend/keys?limit={normal_limit}")
|
||||||
|
assert good_input_response.status_code == 200
|
||||||
|
# Verify the mock was called with the correct parameters
|
||||||
|
mock_query_raw.assert_called_once_with(
|
||||||
|
'SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;', 10
|
||||||
|
)
|
||||||
|
# Reset the mock for the next test
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
# Test with SQL injection payload
|
||||||
|
sql_injection_limit = "10; DROP TABLE spend_logs; --"
|
||||||
|
response = client.get(f"/global/spend/keys?limit={sql_injection_limit}")
|
||||||
|
# Verify the response is a validation error (422)
|
||||||
|
assert response.status_code == 422
|
||||||
|
# Verify the mock was not called with the SQL injection payload
|
||||||
|
# This confirms that the validation happens before the database query
|
||||||
|
mock_query_raw.assert_not_called()
|
||||||
|
# Reset the mock for the next test
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
# Test with non-numeric input
|
||||||
|
non_numeric_limit = "abc"
|
||||||
|
response = client.get(f"/global/spend/keys?limit={non_numeric_limit}")
|
||||||
|
assert response.status_code == 422
|
||||||
|
mock_query_raw.assert_not_called()
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
# Test with negative number
|
||||||
|
negative_limit = "-5"
|
||||||
|
response = client.get(f"/global/spend/keys?limit={negative_limit}")
|
||||||
|
assert response.status_code == 422
|
||||||
|
mock_query_raw.assert_not_called()
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
# Test with zero
|
||||||
|
zero_limit = "0"
|
||||||
|
response = client.get(f"/global/spend/keys?limit={zero_limit}")
|
||||||
|
assert response.status_code == 422
|
||||||
|
mock_query_raw.assert_not_called()
|
||||||
|
mock_query_raw.reset_mock()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue