(feat) UI - Disable Usage Tab once SpendLogs is 1M+ Rows (#7208)

* use utils to set proxy spend logs row count

* store proxy state variables

* fix check for _has_user_setup_sso

* fix proxyStateVariables

* fix dup code

* rename getProxyUISettings

* add fixes

* ui emit num spend logs rows

* test_proxy_server_prisma_setup

* move MAX_SPENDLOG_ROWS_TO_QUERY to constants

* test_get_ui_settings_spend_logs_threshold
Ishaan Jaff 2024-12-12 18:43:17 -08:00 committed by GitHub
parent ce69357e9d
commit b889d7c72f
14 changed files with 230 additions and 41 deletions

View file

@ -66,3 +66,9 @@ LITELLM_CHAT_PROVIDERS = [
"lm_studio",
"galadriel",
]
########################### LiteLLM Proxy Specific Constants ###########################
MAX_SPENDLOG_ROWS_TO_QUERY = (
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
)
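The new constant centralizes the 1M-row cutoff: at startup the proxy caches the SpendLogs row count, and the UI settings endpoint compares that count against this value to decide whether the dashboard may run aggregate spend queries. A minimal sketch of that comparison, with `should_disable_expensive_queries` as an illustrative helper name (the real check lives inline in `get_ui_settings`, shown further down):

```python
# Illustrative threshold check; mirrors the comparison done in get_ui_settings.
MAX_SPENDLOG_ROWS_TO_QUERY = 1_000_000  # same value as litellm.constants

def should_disable_expensive_queries(spend_logs_row_count: int) -> bool:
    # Expensive UI Usage queries are skipped once the table grows past the threshold.
    return spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY

assert should_disable_expensive_queries(1_000_001) is True
assert should_disable_expensive_queries(1_000_000) is False  # the boundary itself stays enabled
```

Note the strict `>`: a table with exactly 1,000,000 rows still serves the Usage tab, which the edge-case test added below exercises.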

View file

@ -261,10 +261,6 @@ class LiteLLMRoutes(enum.Enum):
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes = ["/global/spend/reset", "/key/list"]
sso_only_routes = [
"/sso/get/ui_settings",
]
management_routes = [ # key
"/key/generate",
"/key/{token_id}/regenerate",
@ -347,28 +343,24 @@ class LiteLLMRoutes(enum.Enum):
"/health/services",
] + info_routes
internal_user_routes = (
[
"/key/generate",
"/key/{token_id}/regenerate",
"/key/update",
"/key/delete",
"/key/health",
"/key/info",
"/global/spend/tags",
"/global/spend/keys",
"/global/spend/models",
"/global/spend/provider",
"/global/spend/end_users",
"/global/activity",
"/global/activity/model",
]
+ spend_tracking_routes
+ sso_only_routes
)
internal_user_routes = [
"/key/generate",
"/key/{token_id}/regenerate",
"/key/update",
"/key/delete",
"/key/health",
"/key/info",
"/global/spend/tags",
"/global/spend/keys",
"/global/spend/models",
"/global/spend/provider",
"/global/spend/end_users",
"/global/activity",
"/global/activity/model",
] + spend_tracking_routes
internal_user_view_only_routes = (
spend_tracking_routes + global_spend_tracking_routes + sso_only_routes
spend_tracking_routes + global_spend_tracking_routes
)
self_managed_routes = [
@ -2205,3 +2197,11 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
"""
TypedDict for Proxy state variables.
"""
spend_logs_row_count: int
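Using a `TypedDict` here keeps the proxy state a plain dict at runtime while letting static checkers catch misspelled or unknown keys. A small sketch, assuming only the `spend_logs_row_count` field added in this commit:

```python
from typing import TypedDict

class ProxyStateVariables(TypedDict):
    """TypedDict for Proxy state variables."""

    spend_logs_row_count: int

state: ProxyStateVariables = {"spend_logs_row_count": 0}
state["spend_logs_row_count"] = 250_000
# state["spend_log_rows"] = 1  # mypy/pyright would reject this misspelled key
```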

View file

@ -474,17 +474,17 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
def _has_user_setup_sso():
"""
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID, and UI username environment variables.
Check if the user has set up single sign-on (SSO) by verifying the presence of Microsoft client ID, Google client ID or generic client ID and UI username environment variables.
Returns a boolean indicating whether SSO has been set up.
"""
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
ui_username = os.getenv("UI_USERNAME", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
sso_setup = (
(microsoft_client_id is not None)
or (google_client_id is not None)
or (ui_username is not None)
or (generic_client_id is not None)
)
return sso_setup

View file

@ -65,8 +65,6 @@ class RouteChecks:
pass
elif route == "/team/info":
pass # handled by function itself
elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value:
pass
elif (
route in LiteLLMRoutes.global_spend_tracking_routes.value
and getattr(valid_token, "permissions", None) is not None

View file

@ -0,0 +1,36 @@
"""
This file is used to store the state variables of the proxy server.
Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table.
"""
from typing import Any, Literal
from litellm.proxy._types import ProxyStateVariables
class ProxyState:
"""
Proxy state class has get/set methods for Proxy state variables.
"""
# Note: mypy does not recognize when we fetch ProxyStateVariables.annotations.keys(), so we also need to add the valid keys here
valid_keys_literal = Literal["spend_logs_row_count"]
def __init__(self) -> None:
self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables(
spend_logs_row_count=0,
)
def get_proxy_state_variable(
self,
variable_name: valid_keys_literal,
) -> Any:
return self.proxy_state_variables.get(variable_name, None)
def set_proxy_state_variable(
self,
variable_name: valid_keys_literal,
value: Any,
) -> None:
self.proxy_state_variables[variable_name] = value
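`ProxyState` wraps that TypedDict behind typed get/set accessors; `proxy_server.py` creates a single module-level `proxy_state` instance (shown in the next file) that the rest of the proxy imports. A usage sketch, assuming the module path introduced by this commit:

```python
# Usage sketch for the accessors defined above.
from litellm.proxy.common_utils.proxy_state import ProxyState

proxy_state = ProxyState()  # proxy_server.py keeps one shared instance like this
print(proxy_state.get_proxy_state_variable("spend_logs_row_count"))  # 0 (default)

proxy_state.set_proxy_state_variable(
    variable_name="spend_logs_row_count",
    value=2_500_000,
)
print(proxy_state.get_proxy_state_variable("spend_logs_row_count"))  # 2500000
```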

View file

@ -15,6 +15,7 @@ from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
from litellm.proxy._types import (
LitellmUserRoles,
NewUserRequest,
@ -640,12 +641,15 @@ async def insert_sso_user(
dependencies=[Depends(user_api_key_auth)],
)
async def get_ui_settings(request: Request):
from litellm.proxy.proxy_server import general_settings
from litellm.proxy.proxy_server import general_settings, proxy_state
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
_logout_url = os.getenv("PROXY_LOGOUT_URL", None)
_is_sso_enabled = _has_user_setup_sso()
disable_expensive_db_queries = (
proxy_state.get_proxy_state_variable("spend_logs_row_count")
> MAX_SPENDLOG_ROWS_TO_QUERY
)
default_team_disabled = general_settings.get("default_team_disabled", False)
if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ:
if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true":
@ -656,4 +660,8 @@ async def get_ui_settings(request: Request):
"PROXY_LOGOUT_URL": _logout_url,
"DEFAULT_TEAM_DISABLED": default_team_disabled,
"SSO_ENABLED": _is_sso_enabled,
"NUM_SPEND_LOGS_ROWS": proxy_state.get_proxy_state_variable(
"spend_logs_row_count"
),
"DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries,
}
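The `/sso/get/ui_settings` response now carries two extra keys the dashboard reads: the cached row count and the derived flag. A hedged client-side sketch of fetching it, where the base URL, port, and API key are placeholders rather than values from this commit:

```python
# Placeholder base URL / key; adjust to your deployment. Requires `pip install httpx`.
import httpx

resp = httpx.get(
    "http://localhost:4000/sso/get/ui_settings",
    headers={"Authorization": "Bearer sk-1234"},
)
settings = resp.json()

# Fields added in this commit, alongside the existing PROXY_BASE_URL / SSO_ENABLED keys:
print(settings["NUM_SPEND_LOGS_ROWS"])           # e.g. 1200000
print(settings["DISABLE_EXPENSIVE_DB_QUERIES"])  # True once the row count exceeds the 1M threshold
```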

View file

@ -164,6 +164,7 @@ from litellm.proxy.common_utils.load_config_utils import (
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
@ -327,6 +328,7 @@ premium_user: bool = _license_check.is_premium()
global_max_parallel_request_retries_env: Optional[str] = os.getenv(
"LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES"
)
proxy_state = ProxyState()
if global_max_parallel_request_retries_env is None:
global_max_parallel_request_retries: int = 3
else:
@ -3047,6 +3049,10 @@ class ProxyStartupEvent:
prisma_client.check_view_exists()
) # check if all necessary views exist. Don't block execution
asyncio.create_task(
prisma_client._set_spend_logs_row_count_in_proxy_state()
) # set the spend logs row count in proxy state. Don't block execution
# run a health check to ensure the DB is ready
await prisma_client.health_check()
return prisma_client
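The row count is populated via `asyncio.create_task`, i.e. fire-and-forget: startup does not wait on the `pg_class` lookup, while the DB health check is still awaited. A small sketch of the pattern with a stand-in coroutine instead of the real Prisma client:

```python
# Fire-and-forget startup pattern; the sleeps only simulate work for the demo.
import asyncio

async def _set_spend_logs_row_count_in_proxy_state() -> None:
    await asyncio.sleep(0.1)  # pretend the pg_class lookup takes a moment
    print("spend logs row count cached in proxy state")

async def startup() -> None:
    asyncio.create_task(_set_spend_logs_row_count_in_proxy_state())  # not awaited
    print("startup continues immediately (the health check would be awaited here)")
    await asyncio.sleep(0.2)  # only so the demo task finishes before the loop closes

asyncio.run(startup())
```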

View file

@ -2183,6 +2183,35 @@ class PrismaClient:
)
raise e
async def _get_spend_logs_row_count(self) -> int:
try:
sql_query = """
SELECT reltuples::BIGINT
FROM pg_class
WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
"""
result = await self.db.query_raw(query=sql_query)
return result[0]["reltuples"]
except Exception as e:
verbose_proxy_logger.error(
f"Error getting LiteLLM_SpendLogs row count: {e}"
)
return 0
async def _set_spend_logs_row_count_in_proxy_state(self) -> None:
"""
Set the `LiteLLM_SpendLogs` row count in proxy state.
This is used later to determine if we should run expensive UI Usage queries.
"""
from litellm.proxy.proxy_server import proxy_state
_num_spend_logs_rows = await self._get_spend_logs_row_count()
proxy_state.set_proxy_state_variable(
variable_name="spend_logs_row_count",
value=_num_spend_logs_rows,
)
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
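Rather than `SELECT COUNT(*)`, which would itself be an expensive sequential scan on a 1M+ row table, `_get_spend_logs_row_count` reads `pg_class.reltuples`, the Postgres planner's row estimate maintained by `VACUUM`/`ANALYZE`. The estimate can lag the true count, but it is effectively free and accurate enough for a threshold check. A standalone sketch of the same query using `asyncpg` (an assumption for illustration; the proxy issues it through Prisma's `query_raw`, and the DSN below is a placeholder):

```python
# Same planner-estimate query, issued directly with asyncpg for illustration.
import asyncio
import asyncpg

async def get_spend_logs_row_estimate(dsn: str) -> int:
    conn = await asyncpg.connect(dsn)
    try:
        # reltuples is a planner estimate refreshed by VACUUM / ANALYZE,
        # so this avoids the full scan a COUNT(*) would trigger.
        row = await conn.fetchrow(
            """
            SELECT reltuples::BIGINT AS estimate
            FROM pg_class
            WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
            """
        )
        return int(row["estimate"])
    finally:
        await conn.close()

# asyncio.run(get_spend_logs_row_estimate("postgresql://user:pass@localhost:5432/litellm"))
```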

View file

@ -10,6 +10,7 @@ import litellm.proxy.proxy_server
load_dotenv()
import io
import json
import os
# this file is to test litellm/proxy
@ -2064,7 +2065,7 @@ async def test_proxy_model_group_info_rerank(prisma_client):
@pytest.mark.asyncio
async def test_proxy_server_prisma_setup():
from litellm.proxy.proxy_server import ProxyStartupEvent
from litellm.proxy.proxy_server import ProxyStartupEvent, proxy_state
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
@ -2077,6 +2078,9 @@ async def test_proxy_server_prisma_setup():
mock_client.connect = AsyncMock() # Mock the connect method
mock_client.check_view_exists = AsyncMock() # Mock the check_view_exists method
mock_client.health_check = AsyncMock() # Mock the health_check method
mock_client._set_spend_logs_row_count_in_proxy_state = (
AsyncMock()
) # Mock the _set_spend_logs_row_count_in_proxy_state method
await ProxyStartupEvent._setup_prisma_client(
database_url=os.getenv("DATABASE_URL"),
@ -2092,6 +2096,10 @@ async def test_proxy_server_prisma_setup():
# This is how we ensure the DB is ready before proceeding
mock_client.health_check.assert_called_once()
# check that the spend logs row count is set in proxy state
mock_client._set_spend_logs_row_count_in_proxy_state.assert_called_once()
assert proxy_state.get_proxy_state_variable("spend_logs_row_count") is not None
@pytest.mark.asyncio
async def test_proxy_server_prisma_setup_invalid_db():
@ -2125,3 +2133,57 @@ async def test_proxy_server_prisma_setup_invalid_db():
if _old_db_url:
os.environ["DATABASE_URL"] = _old_db_url
@pytest.mark.asyncio
async def test_get_ui_settings_spend_logs_threshold():
"""
Test that get_ui_settings correctly sets DISABLE_EXPENSIVE_DB_QUERIES based on spend_logs_row_count threshold
"""
from litellm.proxy.management_endpoints.ui_sso import get_ui_settings
from litellm.proxy.proxy_server import proxy_state
from fastapi import Request
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
# Create a mock request
mock_request = Request(
scope={
"type": "http",
"headers": [],
"method": "GET",
"scheme": "http",
"server": ("testserver", 80),
"path": "/sso/get/ui_settings",
"query_string": b"",
}
)
# Test case 1: When spend_logs_row_count > MAX_SPENDLOG_ROWS_TO_QUERY
proxy_state.set_proxy_state_variable(
"spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY + 1
)
response = await get_ui_settings(mock_request)
print("response from get_ui_settings", json.dumps(response, indent=4))
assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is True
assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY + 1
# Test case 2: When spend_logs_row_count < MAX_SPENDLOG_ROWS_TO_QUERY
proxy_state.set_proxy_state_variable(
"spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY - 1
)
response = await get_ui_settings(mock_request)
print("response from get_ui_settings", json.dumps(response, indent=4))
assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY - 1
# Test case 3: Edge case - exactly MAX_SPENDLOG_ROWS_TO_QUERY
proxy_state.set_proxy_state_variable(
"spend_logs_row_count", MAX_SPENDLOG_ROWS_TO_QUERY
)
response = await get_ui_settings(mock_request)
print("response from get_ui_settings", json.dumps(response, indent=4))
assert response["DISABLE_EXPENSIVE_DB_QUERIES"] is False
assert response["NUM_SPEND_LOGS_ROWS"] == MAX_SPENDLOG_ROWS_TO_QUERY
# Clean up
proxy_state.set_proxy_state_variable("spend_logs_row_count", 0)

View file

@ -17,7 +17,7 @@ import {
userCreateCall,
modelAvailableCall,
invitationCreateCall,
getProxyBaseUrlAndLogoutUrl,
getProxyUISettings,
} from "./networking";
const { Option } = Select;
@ -82,7 +82,7 @@ const Createuser: React.FC<CreateuserProps> = ({
setUserModels(availableModels);
// get ui settings
const uiSettingsResponse = await getProxyBaseUrlAndLogoutUrl(accessToken);
const uiSettingsResponse = await getProxyUISettings(accessToken);
console.log("uiSettingsResponse:", uiSettingsResponse);
setUISettings(uiSettingsResponse);

View file

@ -1,7 +1,7 @@
import React, { useState, useEffect } from "react";
import { Select, SelectItem, Text, Title } from "@tremor/react";
import { ProxySettings, UserInfo } from "./user_dashboard";
import { getProxyBaseUrlAndLogoutUrl } from "./networking"
import { getProxyUISettings } from "./networking"
interface DashboardTeamProps {
teams: Object[] | null;
@ -39,7 +39,7 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
const getProxySettings = async () => {
if (proxySettings === null && accessToken) {
const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken);
const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
setProxySettings(proxy_settings);
}
};

View file

@ -2818,7 +2818,7 @@ export const healthCheckCall = async (accessToken: String) => {
}
};
export const getProxyBaseUrlAndLogoutUrl = async (
export const getProxyUISettings = async (
accessToken: String,
) => {
/**

View file

@ -3,6 +3,7 @@ import { BarChart, BarList, Card, Title, Table, TableHead, TableHeaderCell, Tabl
import React, { useState, useEffect } from "react";
import ViewUserSpend from "./view_user_spend";
import { ProxySettings } from "./user_dashboard";
import {
Grid, Col, Text,
LineChart, TabPanel, TabPanels,
@ -35,6 +36,7 @@ import {
adminspendByProvider,
adminGlobalActivity,
adminGlobalActivityPerModel,
getProxyUISettings
} from "./networking";
import { start } from "repl";
console.log("process.env.NODE_ENV", process.env.NODE_ENV);
@ -166,6 +168,7 @@ const UsagePage: React.FC<UsagePageProps> = ({
from: new Date(Date.now() - 7 * 24 * 60 * 60 * 1000),
to: new Date(),
});
const [proxySettings, setProxySettings] = useState<ProxySettings | null>(null);
const firstDay = new Date(
currentDate.getFullYear(),
@ -194,6 +197,22 @@ const UsagePage: React.FC<UsagePageProps> = ({
return formatter.format(number);
}
useEffect(() => {
const fetchProxySettings = async () => {
if (accessToken) {
try {
const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
console.log("usage tab: proxy_settings", proxy_settings);
setProxySettings(proxy_settings);
} catch (error) {
console.error("Error fetching proxy settings:", error);
}
}
};
fetchProxySettings();
}, [accessToken]);
useEffect(() => {
updateTagSpendData(dateValue.from, dateValue.to);
}, [dateValue, selectedTags]);
@ -386,6 +405,9 @@ const UsagePage: React.FC<UsagePageProps> = ({
useEffect(() => {
if (accessToken && token && userRole && userID) {
if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) {
return; // Don't run expensive queries
}
fetchOverallSpend();
@ -405,9 +427,29 @@ const UsagePage: React.FC<UsagePageProps> = ({
}, [accessToken, token, userRole, userID, startTime, endTime]);
if (proxySettings?.DISABLE_EXPENSIVE_DB_QUERIES) {
return (
<div style={{ width: "100%" }} className="p-8">
<Card>
<Title>Database Query Limit Reached</Title>
<Text className="mt-4">
SpendLogs in DB has {proxySettings.NUM_SPEND_LOGS_ROWS} rows.
<br></br>
Please follow our guide to view usage when SpendLogs has more than 1M rows.
</Text>
<Button className="mt-4">
<a href="https://docs.litellm.ai/docs/proxy/spending_monitoring" target="_blank">
View Usage Guide
</a>
</Button>
</Card>
</div>
);
}
return (
<div style={{ width: "100%" }} className="p-8">
<div style={{ width: "100%" }} className="p-8">
<TabGroup>
<TabList className="mt-2">
<Tab>All Up</Tab>

View file

@ -4,7 +4,7 @@ import {
userInfoCall,
modelAvailableCall,
getTotalSpendCall,
getProxyBaseUrlAndLogoutUrl,
getProxyUISettings,
} from "./networking";
import { Grid, Col, Card, Text, Title } from "@tremor/react";
import CreateKey from "./create_key_button";
@ -28,6 +28,8 @@ export interface ProxySettings {
PROXY_LOGOUT_URL: string | null;
DEFAULT_TEAM_DISABLED: boolean;
SSO_ENABLED: boolean;
DISABLE_EXPENSIVE_DB_QUERIES: boolean;
NUM_SPEND_LOGS_ROWS: number;
}
@ -172,7 +174,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
} else {
const fetchData = async () => {
try {
const proxy_settings: ProxySettings = await getProxyBaseUrlAndLogoutUrl(accessToken);
const proxy_settings: ProxySettings = await getProxyUISettings(accessToken);
setProxySettings(proxy_settings);
const response = await userInfoCall(