diff --git a/litellm/constants.py b/litellm/constants.py index 36fd8df3ea..184aa1b559 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -66,3 +66,9 @@ LITELLM_CHAT_PROVIDERS = [ "lm_studio", "galadriel", ] + + +########################### LiteLLM Proxy Specific Constants ########################### +MAX_SPENDLOG_ROWS_TO_QUERY = ( + 1_000_000 # if spendLogs has more than 1M rows, do not query the DB +) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 1a0ae26eae..899c9d0ccd 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -261,10 +261,6 @@ class LiteLLMRoutes(enum.Enum): # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend master_key_only_routes = ["/global/spend/reset", "/key/list"] - sso_only_routes = [ - "/sso/get/ui_settings", - ] - management_routes = [ # key "/key/generate", "/key/{token_id}/regenerate", @@ -347,28 +343,24 @@ class LiteLLMRoutes(enum.Enum): "/health/services", ] + info_routes - internal_user_routes = ( - [ - "/key/generate", - "/key/{token_id}/regenerate", - "/key/update", - "/key/delete", - "/key/health", - "/key/info", - "/global/spend/tags", - "/global/spend/keys", - "/global/spend/models", - "/global/spend/provider", - "/global/spend/end_users", - "/global/activity", - "/global/activity/model", - ] - + spend_tracking_routes - + sso_only_routes - ) + internal_user_routes = [ + "/key/generate", + "/key/{token_id}/regenerate", + "/key/update", + "/key/delete", + "/key/health", + "/key/info", + "/global/spend/tags", + "/global/spend/keys", + "/global/spend/models", + "/global/spend/provider", + "/global/spend/end_users", + "/global/activity", + "/global/activity/model", + ] + spend_tracking_routes internal_user_view_only_routes = ( - spend_tracking_routes + global_spend_tracking_routes + sso_only_routes + spend_tracking_routes + global_spend_tracking_routes ) self_managed_routes = [ @@ -2205,3 +2197,11 @@ class 
ProviderBudgetResponse(LiteLLMPydanticObjectBase): providers: Dict[str, ProviderBudgetResponseObject] = ( {} ) # Dictionary mapping provider names to their budget configurations + + +class ProxyStateVariables(TypedDict): + """ + TypedDict for Proxy state variables. + """ + + spend_logs_row_count: int diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index cc0b42120e..046f94325f 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -474,17 +474,17 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool: def _has_user_setup_sso(): """ - Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables. + Check if the user has set up single sign-on (SSO) by verifying the presence of the Microsoft client ID, Google client ID, or generic client ID environment variable. Returns a boolean indicating whether SSO has been set up. 
""" microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - ui_username = os.getenv("UI_USERNAME", None) + generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) sso_setup = ( (microsoft_client_id is not None) or (google_client_id is not None) - or (ui_username is not None) + or (generic_client_id is not None) ) return sso_setup diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index 9496776a82..4deb4468e0 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -65,8 +65,6 @@ class RouteChecks: pass elif route == "/team/info": pass # handled by function itself - elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value: - pass elif ( route in LiteLLMRoutes.global_spend_tracking_routes.value and getattr(valid_token, "permissions", None) is not None diff --git a/litellm/proxy/common_utils/proxy_state.py b/litellm/proxy/common_utils/proxy_state.py new file mode 100644 index 0000000000..edd18c603d --- /dev/null +++ b/litellm/proxy/common_utils/proxy_state.py @@ -0,0 +1,36 @@ +""" +This file is used to store the state variables of the proxy server. + +Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table. +""" + +from typing import Any, Literal + +from litellm.proxy._types import ProxyStateVariables + + +class ProxyState: + """ + Proxy state class has get/set methods for Proxy state variables. 
+ """ + + # Note: mypy does not recognize when we fetch ProxyStateVariables.annotations.keys(), so we also need to add the valid keys here + valid_keys_literal = Literal["spend_logs_row_count"] + + def __init__(self) -> None: + self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables( + spend_logs_row_count=0, + ) + + def get_proxy_state_variable( + self, + variable_name: valid_keys_literal, + ) -> Any: + return self.proxy_state_variables.get(variable_name, None) + + def set_proxy_state_variable( + self, + variable_name: valid_keys_literal, + value: Any, + ) -> None: + self.proxy_state_variables[variable_name] = value diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 9a49646e69..cec08ddcaa 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -15,6 +15,7 @@ from fastapi.responses import RedirectResponse import litellm from litellm._logging import verbose_proxy_logger +from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY from litellm.proxy._types import ( LitellmUserRoles, NewUserRequest, @@ -640,12 +641,15 @@ async def insert_sso_user( dependencies=[Depends(user_api_key_auth)], ) async def get_ui_settings(request: Request): - from litellm.proxy.proxy_server import general_settings + from litellm.proxy.proxy_server import general_settings, proxy_state _proxy_base_url = os.getenv("PROXY_BASE_URL", None) _logout_url = os.getenv("PROXY_LOGOUT_URL", None) _is_sso_enabled = _has_user_setup_sso() - + disable_expensive_db_queries = ( + proxy_state.get_proxy_state_variable("spend_logs_row_count") + > MAX_SPENDLOG_ROWS_TO_QUERY + ) default_team_disabled = general_settings.get("default_team_disabled", False) if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ: if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true": @@ -656,4 +660,8 @@ async def get_ui_settings(request: Request): "PROXY_LOGOUT_URL": _logout_url, "DEFAULT_TEAM_DISABLED": 
default_team_disabled, "SSO_ENABLED": _is_sso_enabled, + "NUM_SPEND_LOGS_ROWS": proxy_state.get_proxy_state_variable( + "spend_logs_row_count" + ), + "DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries, } diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 93df33d757..f002306367 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -164,6 +164,7 @@ from litellm.proxy.common_utils.load_config_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import ( remove_sensitive_info_from_deployment, ) +from litellm.proxy.common_utils.proxy_state import ProxyState from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config @@ -327,6 +328,7 @@ premium_user: bool = _license_check.is_premium() global_max_parallel_request_retries_env: Optional[str] = os.getenv( "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" ) +proxy_state = ProxyState() if global_max_parallel_request_retries_env is None: global_max_parallel_request_retries: int = 3 else: @@ -3047,6 +3049,10 @@ class ProxyStartupEvent: prisma_client.check_view_exists() ) # check if all necessary views exist. Don't block execution + asyncio.create_task( + prisma_client._set_spend_logs_row_count_in_proxy_state() + ) # set the spend logs row count in proxy state. 
Don't block execution + # run a health check to ensure the DB is ready await prisma_client.health_check() return prisma_client diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index b5f26cb126..5dcb3f84d3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -2183,6 +2183,35 @@ class PrismaClient: ) raise e + async def _get_spend_logs_row_count(self) -> int: + try: + sql_query = """ + SELECT reltuples::BIGINT + FROM pg_class + WHERE oid = '"LiteLLM_SpendLogs"'::regclass; + """ + result = await self.db.query_raw(query=sql_query) + return result[0]["reltuples"] + except Exception as e: + verbose_proxy_logger.error( + f"Error getting LiteLLM_SpendLogs row count: {e}" + ) + return 0 + + async def _set_spend_logs_row_count_in_proxy_state(self) -> None: + """ + Set the `LiteLLM_SpendLogs` row count in proxy state. + + This is used later to determine if we should run expensive UI Usage queries. + """ + from litellm.proxy.proxy_server import proxy_state + + _num_spend_logs_rows = await self._get_spend_logs_row_count() + proxy_state.set_proxy_state_variable( + variable_name="spend_logs_row_count", + value=_num_spend_logs_rows, + ) + ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index 64bb67b58c..71579dd15b 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -10,6 +10,7 @@ import litellm.proxy.proxy_server load_dotenv() import io +import json import os # this file is to test litellm/proxy @@ -2064,7 +2065,7 @@ async def test_proxy_model_group_info_rerank(prisma_client): @pytest.mark.asyncio async def test_proxy_server_prisma_setup(): - from litellm.proxy.proxy_server import ProxyStartupEvent + from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state from litellm.proxy.utils import ProxyLogging from litellm.caching import 
DualCache @@ -2077,6 +2078,9 @@ async def test_proxy_server_prisma_setup(): mock_client.connect = AsyncMock() # Mock the connect method mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method mock_client.health_check = AsyncMock() # Mock the health_check method + mock_client._set_spend_logs_row_count_in_proxy_state = ( + AsyncMock() + ) # Mock the _set_spend_logs_row_count_in_proxy_state method await ProxyStartupEvent._setup_prisma_client( database_url=os.getenv("DATABASE_URL"), @@ -2092,6 +2096,10 @@ async def test_proxy_server_prisma_setup(): # This is how we ensure the DB is ready before proceeding mock_client.health_check.assert_called_once() + # check that the spend logs row count is set in proxy state + mock_client._set_spend_logs_row_count_in_proxy_state.assert_called_once() + assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None + @pytest.mark.asyncio async def test_proxy_server_prisma_setup_invalid_db(): @@ -2125,3 +2133,57 @@ async def test_proxy_server_prisma_setup_invalid_db(): if _old_db_url: os.environ["DATABASE_URL"] = _old_db_url + + +@pytest.mark.asyncio +async def test_get_ui_settings_spend_logs_threshold(): + """ + Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold + """ + from litellm.proxy.management_endpoints.ui_sso import get_ui_settings + from litellm.proxy.proxy_server import proxy_state + from fastapi import Request + from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY + + # Create a mock request + mock_request = Request( + scope={ + "type": "http", + "headers": [], + "method": "GET", + "scheme": "http", + "server": ("testserver", 80), + "path": "/sso/get/ui_settings", + "query_string": b"", + } + ) + + # Test case 1: When spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY + proxy_state.set_proxy_state_variable( + "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY + 1 + ) + response = await get_ui_settings(mock_request) + 
print("response from get_ui_settings", json.dumps(response, indent=4)) + assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is True + assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY + 1 + + # Test case 2: When spend_logs_row_count < MAX_SPENDLOG_ROWS_TO_QUERY + proxy_state.set_proxy_state_variable( + "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY - 1 + ) + response = await get_ui_settings(mock_request) + print("response from get_ui_settings", json.dumps(response, indent=4)) + assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False + assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY - 1 + + # Test case 3: Edge case - exactly MAX_SPENDLOG_ROWS_TO_QUERY + proxy_state.set_proxy_state_variable( + "spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY + ) + response = await get_ui_settings(mock_request) + print("response from get_ui_settings", json.dumps(response, indent=4)) + assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False + assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY + + # Clean up + proxy_state.set_proxy_state_variable("spend_logs_row_count", 0) diff --git a/ui/litellm-dashboard/src/components/create_user_button.tsx b/ui/litellm-dashboard/src/components/create_user_button.tsx index 21b99c6668..ade39699b5 100644 --- a/ui/litellm-dashboard/src/components/create_user_button.tsx +++ b/ui/litellm-dashboard/src/components/create_user_button.tsx @@ -17,7 +17,7 @@ import { userCreateCall, modelAvailableCall, invitationCreateCall, - getProxyBaseUrlAndLogoutUrl, + getProxyUISettings, } from "./networking"; const { Option } = Select; @@ -82,7 +82,7 @@ const Createuser: React.FC = ({ setUserModels(availableModels); // get ui settings - const uiSettingsResponse = await getProxyBaseUrlAndLogoutUrl(accessToken); + const uiSettingsResponse = await getProxyUISettings(accessToken); console.log("uiSettingsResponse:", uiSettingsResponse); setUISettings(uiSettingsResponse); diff --git 
a/ui/litellm-dashboard/src/components/dashboard_default_team.tsx b/ui/litellm-dashboard/src/components/dashboard_default_team.tsx index 04481eadc9..5235754a7d 100644 --- a/ui/litellm-dashboard/src/components/dashboard_default_team.tsx +++ b/ui/litellm-dashboard/src/components/dashboard_default_team.tsx @@ -1,7 +1,7 @@ import React, { useState, useEffect } from "react"; import { Select, SelectItem, Text, Title } from "@tremor/react"; import { ProxySettings, UserInfo } from "./user_dashboard"; -import { getProxyBaseUrlAndLogoutUrl } from "./networking" +import { getProxyUISettings } from "./networking" interface DashboardTeamProps { teams: Object[] | null; @@ -39,7 +39,7 @@ const DashboardTeam: React.FC = ({ const getProxySettings = async () => { if (proxySettings === null && accessToken) { - const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken); + const proxy_settings: ProxySettings = await getProxyUISettings(accessToken); setProxySettings(proxy_settings); } }; diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 990345d441..d91b5fe683 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -2818,7 +2818,7 @@ export const healthCheckCall = async (accessToken: String) => { } }; -export const getProxyBaseUrlAndLogoutUrl = async ( +export const getProxyUISettings = async ( accessToken: String, ) => { /** diff --git a/ui/litellm-dashboard/src/components/usage.tsx b/ui/litellm-dashboard/src/components/usage.tsx index fbe43aa12f..5c9cac84bf 100644 --- a/ui/litellm-dashboard/src/components/usage.tsx +++ b/ui/litellm-dashboard/src/components/usage.tsx @@ -3,6 +3,7 @@ import { BarChart, BarList, Card, Title, Table, TableHead, TableHeaderCell, Tabl import React, { useState, useEffect } from "react"; import ViewUserSpend from "./view_user_spend"; +import { ProxySettings } from "./user_dashboard"; import { Grid, 
Col, Text, LineChart, TabPanel, TabPanels, @@ -35,6 +36,7 @@ import { adminspendByProvider, adminGlobalActivity, adminGlobalActivityPerModel, + getProxyUISettings } from "./networking"; import { start } from "repl"; console.log("process.env.NODE_ENV", process.env.NODE_ENV); @@ -166,6 +168,7 @@ const UsagePage: React.FC = ({ from: new Date(Date.now() - 7 * 24 * 60 * 60 * 1000), to: new Date(), }); + const [proxySettings, setProxySettings] = useState(null); const firstDay = new Date( currentDate.getFullYear(), @@ -194,6 +197,22 @@ const UsagePage: React.FC = ({ return formatter.format(number); } + useEffect(() => { + const fetchProxySettings = async () => { + if (accessToken) { + try { + const proxy_settings: ProxySettings = await getProxyUISettings(accessToken); + console.log("usage tab: proxy_settings", proxy_settings); + setProxySettings(proxy_settings); + } catch (error) { + console.error("Error fetching proxy settings:", error); + } + } + }; + fetchProxySettings(); + }, [accessToken]); + + useEffect(() => { updateTagSpendData(dateValue.from, dateValue.to); }, [dateValue, selectedTags]); @@ -386,6 +405,9 @@ const UsagePage: React.FC = ({ useEffect(() => { if (accessToken && token && userRole && userID) { + if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) { + return; // Don't run expensive queries + } fetchOverallSpend(); @@ -405,9 +427,29 @@ const UsagePage: React.FC = ({ }, [accessToken, token, userRole, userID, startTime, endTime]); + if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) { + return ( +
+ + Database Query Limit Reached + + SpendLogs in DB has {proxySettings.NUM_SPEND_LOGS_ROWS} rows. +

+ Please follow our guide to view usage when SpendLogs has more than 1M rows. +
+ +
+
+ ); + } + + return ( -
- +
All Up diff --git a/ui/litellm-dashboard/src/components/user_dashboard.tsx b/ui/litellm-dashboard/src/components/user_dashboard.tsx index 527d16c231..9c213bc3c5 100644 --- a/ui/litellm-dashboard/src/components/user_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/user_dashboard.tsx @@ -4,7 +4,7 @@ import { userInfoCall, modelAvailableCall, getTotalSpendCall, - getProxyBaseUrlAndLogoutUrl, + getProxyUISettings, } from "./networking"; import { Grid, Col, Card, Text, Title } from "@tremor/react"; import CreateKey from "./create_key_button"; @@ -28,6 +28,8 @@ export interface ProxySettings { PROXY_LOGOUT_URL: string | null; DEFAULT_TEAM_DISABLED: boolean; SSO_ENABLED: boolean; + DISABLE_EXPENSIVE_DB_QUERIES: boolean; + NUM_SPEND_LOGS_ROWS: number; } @@ -172,7 +174,7 @@ const UserDashboard: React.FC = ({ } else { const fetchData = async () => { try { - const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken); + const proxy_settings: ProxySettings = await getProxyUISettings(accessToken); setProxySettings(proxy_settings); const response = await userInfoCall(