(feat) UI - Disable Usage Tab once SpendLogs is 1M+ Rows (#7208)

* use utils to set proxy spend logs row count

* store proxy state variables

* fix check for _has_user_setup_sso

* fix proxyStateVariables

* fix dup code

* rename getProxyUISettings

* add fixes

* ui emit num spend logs rows

* test_proxy_server_prisma_setup

* move MAX_SPENDLOG_ROWS_TO_QUERY to constants

* test_get_ui_settings_spend_logs_threshold
This commit is contained in:
Ishaan Jaff 2024-12-12 18:43:17 -08:00 committed by GitHub
parent ce69357e9d
commit b889d7c72f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 230 additions and 41 deletions

View file

@ -66,3 +66,9 @@ LITELLM_CHAT_PROVIDERS = [
"lm_studio", "lm_studio",
"galadriel", "galadriel",
] ]
########################### LiteLLM Proxy Specific Constants ###########################
# Threshold used by the UI settings endpoint: when the `LiteLLM_SpendLogs`
# table is estimated to have more rows than this, expensive usage queries
# against it are disabled.
MAX_SPENDLOG_ROWS_TO_QUERY = (
    1_000_000  # if spendLogs has more than 1M rows, do not query the DB
)

View file

@ -261,10 +261,6 @@ class LiteLLMRoutes(enum.Enum):
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes = ["/global/spend/reset", "/key/list"] master_key_only_routes = ["/global/spend/reset", "/key/list"]
sso_only_routes = [
"/sso/get/ui_settings",
]
management_routes = [ # key management_routes = [ # key
"/key/generate", "/key/generate",
"/key/{token_id}/regenerate", "/key/{token_id}/regenerate",
@ -347,8 +343,7 @@ class LiteLLMRoutes(enum.Enum):
"/health/services", "/health/services",
] + info_routes ] + info_routes
internal_user_routes = ( internal_user_routes = [
[
"/key/generate", "/key/generate",
"/key/{token_id}/regenerate", "/key/{token_id}/regenerate",
"/key/update", "/key/update",
@ -362,13 +357,10 @@ class LiteLLMRoutes(enum.Enum):
"/global/spend/end_users", "/global/spend/end_users",
"/global/activity", "/global/activity",
"/global/activity/model", "/global/activity/model",
] ] + spend_tracking_routes
+ spend_tracking_routes
+ sso_only_routes
)
internal_user_view_only_routes = ( internal_user_view_only_routes = (
spend_tracking_routes + global_spend_tracking_routes + sso_only_routes spend_tracking_routes + global_spend_tracking_routes
) )
self_managed_routes = [ self_managed_routes = [
@ -2205,3 +2197,11 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
providers: Dict[str, ProviderBudgetResponseObject] = ( providers: Dict[str, ProviderBudgetResponseObject] = (
{} {}
) # Dictionary mapping provider names to their budget configurations ) # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
    """
    TypedDict for Proxy state variables.

    Keys here are the only valid names accepted by
    ProxyState.get/set_proxy_state_variable.
    """

    # Number of rows currently in the `LiteLLM_SpendLogs` table; used to
    # decide whether expensive UI usage queries should be disabled.
    spend_logs_row_count: int

View file

@ -474,17 +474,17 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
def _has_user_setup_sso(): def _has_user_setup_sso():
""" """
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables. Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
Returns a boolean indicating whether SSO has been set up. Returns a boolean indicating whether SSO has been set up.
""" """
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
ui_username = os.getenv("UI_USERNAME", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
sso_setup = ( sso_setup = (
(microsoft_client_id is not None) (microsoft_client_id is not None)
or (google_client_id is not None) or (google_client_id is not None)
or (ui_username is not None) or (generic_client_id is not None)
) )
return sso_setup return sso_setup

View file

@ -65,8 +65,6 @@ class RouteChecks:
pass pass
elif route == "/team/info": elif route == "/team/info":
pass # handled by function itself pass # handled by function itself
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
pass
elif ( elif (
route in LiteLLMRoutes.global_spend_tracking_routes.value route in LiteLLMRoutes.global_spend_tracking_routes.value
and getattr(valid_token, "permissions", None) is not None and getattr(valid_token, "permissions", None) is not None

View file

@ -0,0 +1,36 @@
"""
This file is used to store the state variables of the proxy server.
Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table.
"""
from typing import Any, Literal
from litellm.proxy._types import ProxyStateVariables
class ProxyState:
    """
    Holds mutable, process-wide proxy state behind typed get/set accessors.

    Example: `spend_logs_row_count` tracks how many rows the
    `LiteLLM_SpendLogs` table currently has.
    """

    # Note: mypy does not recognize when we fetch ProxyStateVariables.annotations.keys(), so we also need to add the valid keys here
    valid_keys_literal = Literal["spend_logs_row_count"]

    def __init__(self) -> None:
        # Every tracked variable starts at its zero value.
        self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables(
            spend_logs_row_count=0,
        )

    def get_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
    ) -> Any:
        """Return the current value of *variable_name*, or None when unset."""
        try:
            return self.proxy_state_variables[variable_name]
        except KeyError:
            return None

    def set_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
        value: Any,
    ) -> None:
        """Store *value* under *variable_name*, replacing any prior value."""
        self.proxy_state_variables[variable_name] = value

View file

@ -15,6 +15,7 @@ from fastapi.responses import RedirectResponse
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
from litellm.proxy._types import ( from litellm.proxy._types import (
LitellmUserRoles, LitellmUserRoles,
NewUserRequest, NewUserRequest,
@ -640,12 +641,15 @@ async def insert_sso_user(
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def get_ui_settings(request: Request): async def get_ui_settings(request: Request):
from litellm.proxy.proxy_server import general_settings from litellm.proxy.proxy_server import general_settings, proxy_state
_proxy_base_url = os.getenv("PROXY_BASE_URL", None) _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
_logout_url = os.getenv("PROXY_LOGOUT_URL", None) _logout_url = os.getenv("PROXY_LOGOUT_URL", None)
_is_sso_enabled = _has_user_setup_sso() _is_sso_enabled = _has_user_setup_sso()
disable_expensive_db_queries = (
proxy_state.get_proxy_state_variable("spend_logs_row_count")
> MAX_SPENDLOG_ROWS_TO_QUERY
)
default_team_disabled = general_settings.get("default_team_disabled", False) default_team_disabled = general_settings.get("default_team_disabled", False)
if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ: if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ:
if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true": if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true":
@ -656,4 +660,8 @@ async def get_ui_settings(request: Request):
"PROXY_LOGOUT_URL": _logout_url, "PROXY_LOGOUT_URL": _logout_url,
"DEFAULT_TEAM_DISABLED": default_team_disabled, "DEFAULT_TEAM_DISABLED": default_team_disabled,
"SSO_ENABLED": _is_sso_enabled, "SSO_ENABLED": _is_sso_enabled,
"NUM_SPEND_LOGS_ROWS": proxy_state.get_proxy_state_variable(
"spend_logs_row_count"
),
"DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries,
} }

View file

@ -164,6 +164,7 @@ from litellm.proxy.common_utils.load_config_utils import (
from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment, remove_sensitive_info_from_deployment,
) )
from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
@ -327,6 +328,7 @@ premium_user: bool = _license_check.is_premium()
global_max_parallel_request_retries_env: Optional[str] = os.getenv( global_max_parallel_request_retries_env: Optional[str] = os.getenv(
"LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES"
) )
proxy_state = ProxyState()
if global_max_parallel_request_retries_env is None: if global_max_parallel_request_retries_env is None:
global_max_parallel_request_retries: int = 3 global_max_parallel_request_retries: int = 3
else: else:
@ -3047,6 +3049,10 @@ class ProxyStartupEvent:
prisma_client.check_view_exists() prisma_client.check_view_exists()
) # check if all necessary views exist. Don't block execution ) # check if all necessary views exist. Don't block execution
asyncio.create_task(
prisma_client._set_spend_logs_row_count_in_proxy_state()
) # set the spend logs row count in proxy state. Don't block execution
# run a health check to ensure the DB is ready # run a health check to ensure the DB is ready
await prisma_client.health_check() await prisma_client.health_check()
return prisma_client return prisma_client

View file

@ -2183,6 +2183,35 @@ class PrismaClient:
) )
raise e raise e
async def _get_spend_logs_row_count(self) -> int:
    """
    Return an estimated row count for the `LiteLLM_SpendLogs` table.

    Reads PostgreSQL planner statistics (`pg_class.reltuples`) instead of
    running `COUNT(*)`, so the check stays cheap even when the table is huge.

    Returns:
        int: estimated number of rows; 0 if the estimate is unavailable or
        the query fails.
    """
    try:
        sql_query = """
        SELECT reltuples::BIGINT
        FROM pg_class
        WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
        """
        result = await self.db.query_raw(query=sql_query)
        # reltuples is only an estimate and is -1 for tables that have never
        # been vacuumed/analyzed (PostgreSQL >= 14). Treat that as "unknown"
        # and report 0 instead of a negative count.
        return max(0, int(result[0]["reltuples"]))
    except Exception as e:
        verbose_proxy_logger.error(
            f"Error getting LiteLLM_SpendLogs row count: {e}"
        )
        return 0
async def _set_spend_logs_row_count_in_proxy_state(self) -> None:
    """
    Store the `LiteLLM_SpendLogs` row count in proxy state.

    This is used later to determine if we should run expensive UI Usage
    queries.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import proxy_state

    row_count = await self._get_spend_logs_row_count()
    proxy_state.set_proxy_state_variable(
        variable_name="spend_logs_row_count",
        value=row_count,
    )
### CUSTOM FILE ### ### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:

View file

@ -10,6 +10,7 @@ import litellm.proxy.proxy_server
load_dotenv() load_dotenv()
import io import io
import json
import os import os
# this file is to test litellm/proxy # this file is to test litellm/proxy
@ -2064,7 +2065,7 @@ async def test_proxy_model_group_info_rerank(prisma_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_proxy_server_prisma_setup(): async def test_proxy_server_prisma_setup():
from litellm.proxy.proxy_server import ProxyStartupEvent from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state
from litellm.proxy.utils import ProxyLogging from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache from litellm.caching import DualCache
@ -2077,6 +2078,9 @@ async def test_proxy_server_prisma_setup():
mock_client.connect = AsyncMock() # Mock the connect method mock_client.connect = AsyncMock() # Mock the connect method
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
mock_client.health_check = AsyncMock() # Mock the health_check method mock_client.health_check = AsyncMock() # Mock the health_check method
mock_client._set_spend_logs_row_count_in_proxy_state = (
AsyncMock()
) # Mock the _set_spend_logs_row_count_in_proxy_state method
await ProxyStartupEvent._setup_prisma_client( await ProxyStartupEvent._setup_prisma_client(
database_url=os.getenv("DATABASE_URL"), database_url=os.getenv("DATABASE_URL"),
@ -2092,6 +2096,10 @@ async def test_proxy_server_prisma_setup():
# This is how we ensure the DB is ready before proceeding # This is how we ensure the DB is ready before proceeding
mock_client.health_check.assert_called_once() mock_client.health_check.assert_called_once()
# check that the spend logs row count is set in proxy state
mock_client._set_spend_logs_row_count_in_proxy_state.assert_called_once()
assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db(): async def test_proxy_server_prisma_setup_invalid_db():
@ -2125,3 +2133,57 @@ async def test_proxy_server_prisma_setup_invalid_db():
if _old_db_url: if _old_db_url:
os.environ["DATABASE_URL"] = _old_db_url os.environ["DATABASE_URL"] = _old_db_url
@pytest.mark.asyncio
async def test_get_ui_settings_spend_logs_threshold():
    """
    Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold
    """
    from fastapi import Request

    from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
    from litellm.proxy.management_endpoints.ui_sso import get_ui_settings
    from litellm.proxy.proxy_server import proxy_state

    # Minimal ASGI scope so FastAPI can construct a Request object.
    mock_request = Request(
        scope={
            "type": "http",
            "headers": [],
            "method": "GET",
            "scheme": "http",
            "server": ("testserver", 80),
            "path": "/sso/get/ui_settings",
            "query_string": b"",
        }
    )

    # (row_count, expected DISABLE_EXPENSIVE_DB_QUERIES) — covers above,
    # below, and exactly at the threshold (edge case: strictly-greater check).
    cases = [
        (MAX_SPENDLOG_ROWS_TO_QUERY + 1, True),
        (MAX_SPENDLOG_ROWS_TO_QUERY - 1, False),
        (MAX_SPENDLOG_ROWS_TO_QUERY, False),
    ]
    for row_count, expect_disabled in cases:
        proxy_state.set_proxy_state_variable("spend_logs_row_count", row_count)
        response = await get_ui_settings(mock_request)
        print("response from get_ui_settings", json.dumps(response, indent=4))
        assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is expect_disabled
        assert response["NUM_SPEND_LOGS_ROWS"] == row_count

    # Clean up
    proxy_state.set_proxy_state_variable("spend_logs_row_count", 0)

View file

@ -17,7 +17,7 @@ import {
userCreateCall, userCreateCall,
modelAvailableCall, modelAvailableCall,
invitationCreateCall, invitationCreateCall,
getProxyBaseUrlAndLogoutUrl, getProxyUISettings,
} from "./networking"; } from "./networking";
const { Option } = Select; const { Option } = Select;
@ -82,7 +82,7 @@ const Createuser: React.FC<CreateuserProps> = ({
setUserModels(availableModels); setUserModels(availableModels);
// get ui settings // get ui settings
const uiSettingsResponse = await getProxyBaseUrlAndLogoutUrl(accessToken); const uiSettingsResponse = await getProxyUISettings(accessToken);
console.log("uiSettingsResponse:", uiSettingsResponse); console.log("uiSettingsResponse:", uiSettingsResponse);
setUISettings(uiSettingsResponse); setUISettings(uiSettingsResponse);

View file

@ -1,7 +1,7 @@
import React, { useState, useEffect } from "react"; import React, { useState, useEffect } from "react";
import { Select, SelectItem, Text, Title } from "@tremor/react"; import { Select, SelectItem, Text, Title } from "@tremor/react";
import { ProxySettings, UserInfo } from "./user_dashboard"; import { ProxySettings, UserInfo } from "./user_dashboard";
import { getProxyBaseUrlAndLogoutUrl } from "./networking" import { getProxyUISettings } from "./networking"
interface DashboardTeamProps { interface DashboardTeamProps {
teams: Object[] | null; teams: Object[] | null;
@ -39,7 +39,7 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
const getProxySettings = async () => { const getProxySettings = async () => {
if (proxySettings === null && accessToken) { if (proxySettings === null && accessToken) {
const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken); const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
setProxySettings(proxy_settings); setProxySettings(proxy_settings);
} }
}; };

View file

@ -2818,7 +2818,7 @@ export const healthCheckCall = async (accessToken: String) => {
} }
}; };
export const getProxyBaseUrlAndLogoutUrl = async ( export const getProxyUISettings = async (
accessToken: String, accessToken: String,
) => { ) => {
/** /**

View file

@ -3,6 +3,7 @@ import { BarChart, BarList, Card, Title, Table, TableHead, TableHeaderCell, Tabl
import React, { useState, useEffect } from "react"; import React, { useState, useEffect } from "react";
import ViewUserSpend from "./view_user_spend"; import ViewUserSpend from "./view_user_spend";
import { ProxySettings } from "./user_dashboard";
import { import {
Grid, Col, Text, Grid, Col, Text,
LineChart, TabPanel, TabPanels, LineChart, TabPanel, TabPanels,
@ -35,6 +36,7 @@ import {
adminspendByProvider, adminspendByProvider,
adminGlobalActivity, adminGlobalActivity,
adminGlobalActivityPerModel, adminGlobalActivityPerModel,
getProxyUISettings
} from "./networking"; } from "./networking";
import { start } from "repl"; import { start } from "repl";
console.log("process.env.NODE_ENV", process.env.NODE_ENV); console.log("process.env.NODE_ENV", process.env.NODE_ENV);
@ -166,6 +168,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
from: new Date(Date.now() - 7 * 24 * 60 * 60 * 1000), from: new Date(Date.now() - 7 * 24 * 60 * 60 * 1000),
to: new Date(), to: new Date(),
}); });
const [proxySettings, setProxySettings] = useState<ProxySettings | null>(null);
const firstDay = new Date( const firstDay = new Date(
currentDate.getFullYear(), currentDate.getFullYear(),
@ -194,6 +197,22 @@ const UsagePage: React.FC<UsagePageProps> = ({
return formatter.format(number); return formatter.format(number);
} }
useEffect(() => {
const fetchProxySettings = async () => {
if (accessToken) {
try {
const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
console.log("usage tab: proxy_settings", proxy_settings);
setProxySettings(proxy_settings);
} catch (error) {
console.error("Error fetching proxy settings:", error);
}
}
};
fetchProxySettings();
}, [accessToken]);
useEffect(() => { useEffect(() => {
updateTagSpendData(dateValue.from, dateValue.to); updateTagSpendData(dateValue.from, dateValue.to);
}, [dateValue, selectedTags]); }, [dateValue, selectedTags]);
@ -386,6 +405,9 @@ const UsagePage: React.FC<UsagePageProps> = ({
useEffect(() => { useEffect(() => {
if (accessToken && token && userRole && userID) { if (accessToken && token && userRole && userID) {
if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) {
return; // Don't run expensive queries
}
fetchOverallSpend(); fetchOverallSpend();
@ -405,9 +427,29 @@ const UsagePage: React.FC<UsagePageProps> = ({
}, [accessToken, token, userRole, userID, startTime, endTime]); }, [accessToken, token, userRole, userID, startTime, endTime]);
if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) {
return ( return (
<div style={{ width: "100%" }} className="p-8"> <div style={{ width: "100%" }} className="p-8">
<Card>
<Title>Database Query Limit Reached</Title>
<Text className="mt-4">
SpendLogs in DB has {proxySettings.NUM_SPEND_LOGS_ROWS} rows.
<br></br>
Please follow our guide to view usage when SpendLogs has more than 1M rows.
</Text>
<Button className="mt-4">
<a href="https://docs.litellm.ai/docs/proxy/spending_monitoring" target="_blank">
View Usage Guide
</a>
</Button>
</Card>
</div>
);
}
return (
<div style={{ width: "100%" }} className="p-8">
<TabGroup> <TabGroup>
<TabList className="mt-2"> <TabList className="mt-2">
<Tab>All Up</Tab> <Tab>All Up</Tab>

View file

@ -4,7 +4,7 @@ import {
userInfoCall, userInfoCall,
modelAvailableCall, modelAvailableCall,
getTotalSpendCall, getTotalSpendCall,
getProxyBaseUrlAndLogoutUrl, getProxyUISettings,
} from "./networking"; } from "./networking";
import { Grid, Col, Card, Text, Title } from "@tremor/react"; import { Grid, Col, Card, Text, Title } from "@tremor/react";
import CreateKey from "./create_key_button"; import CreateKey from "./create_key_button";
@ -28,6 +28,8 @@ export interface ProxySettings {
PROXY_LOGOUT_URL: string | null; PROXY_LOGOUT_URL: string | null;
DEFAULT_TEAM_DISABLED: boolean; DEFAULT_TEAM_DISABLED: boolean;
SSO_ENABLED: boolean; SSO_ENABLED: boolean;
DISABLE_EXPENSIVE_DB_QUERIES: boolean;
NUM_SPEND_LOGS_ROWS: number;
} }
@ -172,7 +174,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
} else { } else {
const fetchData = async () => { const fetchData = async () => {
try { try {
const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken); const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
setProxySettings(proxy_settings); setProxySettings(proxy_settings);
const response = await userInfoCall( const response = await userInfoCall(