mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(common_daily_activity.py): support empty entity id field (#10175)
* fix(common_daily_activity.py): support empty entity id field allows returning empty response when user is not admin and does not belong to any team * test(test_common_daily_activity.py): add unit testing
This commit is contained in:
parent
72f6bd3972
commit
e0a613f88a
5 changed files with 69 additions and 8 deletions
File diff suppressed because one or more lines are too long
|
@ -644,6 +644,7 @@ async def get_user_object(
|
||||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||||
sso_user_id: Optional[str] = None,
|
sso_user_id: Optional[str] = None,
|
||||||
user_email: Optional[str] = None,
|
user_email: Optional[str] = None,
|
||||||
|
check_db_only: Optional[bool] = None,
|
||||||
) -> Optional[LiteLLM_UserTable]:
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
"""
|
"""
|
||||||
- Check if user id in proxy User Table
|
- Check if user id in proxy User Table
|
||||||
|
@ -655,12 +656,13 @@ async def get_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
if not check_db_only:
|
||||||
if cached_user_obj is not None:
|
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if isinstance(cached_user_obj, dict):
|
if cached_user_obj is not None:
|
||||||
return LiteLLM_UserTable(**cached_user_obj)
|
if isinstance(cached_user_obj, dict):
|
||||||
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
return LiteLLM_UserTable(**cached_user_obj)
|
||||||
return cached_user_obj
|
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
||||||
|
return cached_user_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise Exception("No db connected")
|
raise Exception("No db connected")
|
||||||
|
|
|
@ -154,7 +154,7 @@ async def get_daily_activity(
|
||||||
where_conditions["model"] = model
|
where_conditions["model"] = model
|
||||||
if api_key:
|
if api_key:
|
||||||
where_conditions["api_key"] = api_key
|
where_conditions["api_key"] = api_key
|
||||||
if entity_id:
|
if entity_id is not None:
|
||||||
if isinstance(entity_id, list):
|
if isinstance(entity_id, list):
|
||||||
where_conditions[entity_id_field] = {"in": entity_id}
|
where_conditions[entity_id_field] = {"in": entity_id}
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2149,6 +2149,7 @@ async def get_team_daily_activity(
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
check_db_only=True,
|
||||||
)
|
)
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -2157,6 +2158,7 @@ async def get_team_daily_activity(
|
||||||
"error": "User= {} not found".format(user_api_key_dict.user_id)
|
"error": "User= {} not found".format(user_api_key_dict.user_id)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if team_ids_list is None:
|
if team_ids_list is None:
|
||||||
team_ids_list = user_info.teams
|
team_ids_list = user_info.teams
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||||
|
from litellm.proxy.proxy_server import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_daily_activity_empty_entity_id_list():
|
||||||
|
# Mock PrismaClient
|
||||||
|
mock_prisma = MagicMock()
|
||||||
|
mock_prisma.db = MagicMock()
|
||||||
|
|
||||||
|
# Mock the table methods
|
||||||
|
mock_table = MagicMock()
|
||||||
|
mock_table.count = AsyncMock(return_value=0)
|
||||||
|
mock_table.find_many = AsyncMock(return_value=[])
|
||||||
|
mock_prisma.db.litellm_verificationtoken = MagicMock()
|
||||||
|
mock_prisma.db.litellm_verificationtoken.find_many = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
# Set the table name dynamically
|
||||||
|
mock_prisma.db.litellm_dailyspend = mock_table
|
||||||
|
|
||||||
|
# Call the function with empty entity_id list
|
||||||
|
result = await get_daily_activity(
|
||||||
|
prisma_client=mock_prisma,
|
||||||
|
table_name="litellm_dailyspend",
|
||||||
|
entity_id_field="team_id",
|
||||||
|
entity_id=[],
|
||||||
|
entity_metadata_field=None,
|
||||||
|
start_date="2024-01-01",
|
||||||
|
end_date="2024-01-02",
|
||||||
|
model=None,
|
||||||
|
api_key=None,
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the where conditions were set correctly
|
||||||
|
mock_table.find_many.assert_called_once()
|
||||||
|
call_args = mock_table.find_many.call_args[1]
|
||||||
|
where_conditions = call_args["where"]
|
||||||
|
|
||||||
|
# Check that team_id is set to empty list
|
||||||
|
assert "team_id" in where_conditions
|
||||||
|
assert where_conditions["team_id"] == {"in": []}
|
Loading…
Add table
Add a link
Reference in a new issue