diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index b1af153e81..47fefa048b 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -711,6 +711,11 @@ class DynamoDBArgs(LiteLLMBase): assume_role_aws_session_name: Optional[str] = None +class ConfigFieldUpdate(LiteLLMBase): + field_name: str + field_value: Any + + class ConfigGeneralSettings(LiteLLMBase): """ Documents all the fields supported by `general_settings` in config.yaml diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e9324dd96d..954b3c109a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -234,6 +234,7 @@ class SpecialModelNames(enum.Enum): class CommonProxyErrors(enum.Enum): db_not_connected_error = "DB not connected" no_llm_router = "No models configured on proxy" + not_allowed_access = "Admin-only endpoint. Not allowed to access this." @app.exception_handler(ProxyException) @@ -9389,7 +9390,7 @@ async def auth_callback(request: Request): return RedirectResponse(url=litellm_dashboard_ui) -#### BASIC ENDPOINTS #### +#### CONFIG MANAGEMENT #### @router.post( "/config/update", tags=["config.yaml"], @@ -9525,6 +9526,219 @@ async def update_config(config_info: ConfigYAML): ) +### CONFIG GENERAL SETTINGS +""" +- Update config settings +- Get config settings + +Keep it more precise, to prevent overwrite other values unintentially +""" + + +@router.post( + "/config/field/update", + tags=["config.yaml"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_config_general_settings( + data: ConfigFieldUpdate, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update a specific field in litellm general settings + """ + global prisma_client + ## VALIDATION ## + """ + - Check if prisma_client is None + - Check if user allowed to call this endpoint (admin-only) + - Check if param in general settings + - Check if config value is valid type + """ + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != "proxy_admin": + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.not_allowed_access.value}, + ) + + if data.field_name not in ConfigGeneralSettings.model_fields: + raise HTTPException( + status_code=400, + detail={"error": "Invalid field={} passed in.".format(data.field_name)}, + ) + + try: + cgs = ConfigGeneralSettings(**{data.field_name: data.field_value}) + except: + raise HTTPException( + status_code=400, + detail={ + "error": "Invalid type of field value={} passed in.".format( + type(data.field_value), + ) + }, + ) + + ## get general settings from db + db_general_settings = await prisma_client.db.litellm_config.find_first( + where={"param_name": "general_settings"} + ) + ### update value + + if db_general_settings is None or db_general_settings.param_value is None: + general_settings = {} + else: + general_settings = dict(db_general_settings.param_value) + + ## update db + + general_settings[data.field_name] = data.field_value + + response = await prisma_client.db.litellm_config.upsert( + where={"param_name": "general_settings"}, + data={ + "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore + "update": {"param_value": json.dumps(general_settings)}, # type: ignore + }, + ) + + return response + + +@router.get( + "/config/field/info", + tags=["config.yaml"], + dependencies=[Depends(user_api_key_auth)], +) +async def get_config_general_settings( + field_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global prisma_client + + ## VALIDATION ## + """ + - Check if prisma_client is None + - Check if user allowed to call this endpoint (admin-only) + - Check if param in general settings + """ + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != "proxy_admin": + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.not_allowed_access.value}, + ) + + if field_name not in ConfigGeneralSettings.model_fields: + raise HTTPException( + status_code=400, + detail={"error": "Invalid field={} passed in.".format(field_name)}, + ) + + ## get general settings from db + db_general_settings = await prisma_client.db.litellm_config.find_first( + where={"param_name": "general_settings"} + ) + ### pop the value + + if db_general_settings is None or db_general_settings.param_value is None: + raise HTTPException( + status_code=400, + detail={"error": "Field name={} not in DB".format(field_name)}, + ) + else: + general_settings = dict(db_general_settings.param_value) + + if field_name in general_settings: + return { + "field_name": field_name, + "field_value": general_settings[field_name], + } + else: + raise HTTPException( + status_code=400, + detail={"error": "Field name={} not in DB".format(field_name)}, + ) + + +@router.post( + "/config/field/delete", + tags=["config.yaml"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_config_general_settings( + field_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete the db value of this field in litellm general settings. Resets it to it's initial default value on litellm. + """ + global prisma_client + ## VALIDATION ## + """ + - Check if prisma_client is None + - Check if user allowed to call this endpoint (admin-only) + - Check if param in general settings + """ + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + if user_api_key_dict.user_role != "proxy_admin": + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.not_allowed_access.value}, + ) + + if field_name not in ConfigGeneralSettings.model_fields: + raise HTTPException( + status_code=400, + detail={"error": "Invalid field={} passed in.".format(field_name)}, + ) + + ## get general settings from db + db_general_settings = await prisma_client.db.litellm_config.find_first( + where={"param_name": "general_settings"} + ) + ### pop the value + + if db_general_settings is None or db_general_settings.param_value is None: + raise HTTPException( + status_code=400, + detail={"error": "Field name={} not in config".format(field_name)}, + ) + else: + general_settings = dict(db_general_settings.param_value) + + ## update db + + general_settings.pop(field_name) + + response = await prisma_client.db.litellm_config.upsert( + where={"param_name": "general_settings"}, + data={ + "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore + "update": {"param_value": json.dumps(general_settings)}, # type: ignore + }, + ) + + return response + + @router.get( "/get/config/callbacks", tags=["config.yaml"], @@ -9692,6 +9906,7 @@ async def config_yaml_endpoint(config_info: ConfigYAML): return {"hello": "world"} +#### BASIC ENDPOINTS #### @router.get( "/test", tags=["health"], diff --git a/ui/litellm-dashboard/src/components/general_settings.tsx b/ui/litellm-dashboard/src/components/general_settings.tsx index c2013b1578..2063c04988 100644 --- a/ui/litellm-dashboard/src/components/general_settings.tsx +++ b/ui/litellm-dashboard/src/components/general_settings.tsx @@ -23,12 +23,30 @@ import { AccordionHeader, AccordionList, } from "@tremor/react"; -import { TabPanel, TabPanels, TabGroup, TabList, Tab, Icon } from "@tremor/react"; -import { getCallbacksCall, setCallbacksCall, serviceHealthCheck } from "./networking"; +import { + TabPanel, + TabPanels, + TabGroup, + TabList, + Tab, + Icon, +} from "@tremor/react"; +import { + getCallbacksCall, + setCallbacksCall, + serviceHealthCheck, +} from "./networking"; import { Modal, Form, Input, Select, Button as Button2, message } from "antd"; -import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon, RefreshIcon } from "@heroicons/react/outline"; +import { + InformationCircleIcon, + PencilAltIcon, + PencilIcon, + StatusOnlineIcon, + TrashIcon, + RefreshIcon, +} from "@heroicons/react/outline"; import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider"; -import AddFallbacks from "./add_fallbacks" +import AddFallbacks from "./add_fallbacks"; import openai from "openai"; import Paragraph from "antd/es/skeleton/Paragraph"; @@ -36,7 +54,7 @@ interface GeneralSettingsPageProps { accessToken: string | null; userRole: string | null; userID: string | null; - modelData: any + modelData: any; } async function testFallbackModelResponse( @@ -65,43 +83,63 @@ async function testFallbackModelResponse( }, ], // @ts-ignore - mock_testing_fallbacks: true + mock_testing_fallbacks: true, }); message.success( - Test model={selectedModel}, received model={response.model}. - See window.open('https://docs.litellm.ai/docs/proxy/reliability', '_blank')} style={{ textDecoration: 'underline', color: 'blue' }}>curl + Test model={selectedModel}, received model= + {response.model}. See{" "} + + window.open( + "https://docs.litellm.ai/docs/proxy/reliability", + "_blank" + ) + } + style={{ textDecoration: "underline", color: "blue" }} + > + curl + ); } catch (error) { - message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20); + message.error( + `Error occurred while generating model response. Please try again. Error: ${error}`, + 20 + ); } } interface AccordionHeroProps { selectedStrategy: string | null; strategyArgs: routingStrategyArgs; - paramExplanation: { [key: string]: string } + paramExplanation: { [key: string]: string }; } interface routingStrategyArgs { - ttl?: number; - lowest_latency_buffer?: number; + ttl?: number; + lowest_latency_buffer?: number; } const defaultLowestLatencyArgs: routingStrategyArgs = { - "ttl": 3600, - "lowest_latency_buffer": 0 -} + ttl: 3600, + lowest_latency_buffer: 0, +}; -export const AccordionHero: React.FC = ({ selectedStrategy, strategyArgs, paramExplanation }) => ( +export const AccordionHero: React.FC = ({ + selectedStrategy, + strategyArgs, + paramExplanation, +}) => ( - Routing Strategy Specific Args - - { - selectedStrategy == "latency-based-routing" ? - + + Routing Strategy Specific Args + + + {selectedStrategy == "latency-based-routing" ? ( + @@ -114,51 +152,69 @@ export const AccordionHero: React.FC = ({ selectedStrategy, {param} - {paramExplanation[param]} + + {paramExplanation[param]} + + name={param} + defaultValue={ + typeof value === "object" + ? JSON.stringify(value, null, 2) + : value.toString() + } + /> ))} - - : No specific settings - } - - + + ) : ( + No specific settings + )} + + ); const GeneralSettings: React.FC = ({ accessToken, userRole, userID, - modelData + modelData, }) => { - const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>({}); + const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>( + {} + ); const [isModalVisible, setIsModalVisible] = useState(false); const [form] = Form.useForm(); const [selectedCallback, setSelectedCallback] = useState(null); - const [selectedStrategy, setSelectedStrategy] = useState(null) - const [strategySettings, setStrategySettings] = useState(null); + const [selectedStrategy, setSelectedStrategy] = useState(null); + const [strategySettings, setStrategySettings] = + useState(null); let paramExplanation: { [key: string]: string } = { - "routing_strategy_args": "(dict) Arguments to pass to the routing strategy", - "routing_strategy": "(string) Routing strategy to use", - "allowed_fails": "(int) Number of times a deployment can fail before being added to cooldown", - "cooldown_time": "(int) time in seconds to cooldown a deployment after failure", - "num_retries": "(int) Number of retries for failed requests. Defaults to 0.", - "timeout": "(float) Timeout for requests. Defaults to None.", - "retry_after": "(int) Minimum time to wait before retrying a failed request", - "ttl": "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).", - "lowest_latency_buffer": "(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency)." - } + routing_strategy_args: "(dict) Arguments to pass to the routing strategy", + routing_strategy: "(string) Routing strategy to use", + allowed_fails: + "(int) Number of times a deployment can fail before being added to cooldown", + cooldown_time: + "(int) time in seconds to cooldown a deployment after failure", + num_retries: "(int) Number of retries for failed requests. Defaults to 0.", + timeout: "(float) Timeout for requests. Defaults to None.", + retry_after: "(int) Minimum time to wait before retrying a failed request", + ttl: "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).", + lowest_latency_buffer: + "(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency).", + }; useEffect(() => { if (!accessToken || !userRole || !userID) { @@ -190,8 +246,8 @@ const GeneralSettings: React.FC = ({ return; } - console.log(`received key: ${key}`) - console.log(`routerSettings['fallbacks']: ${routerSettings['fallbacks']}`) + console.log(`received key: ${key}`); + console.log(`routerSettings['fallbacks']: ${routerSettings["fallbacks"]}`); routerSettings["fallbacks"].map((dict: { [key: string]: any }) => { // Check if the dictionary has the specified key and delete it if present @@ -202,18 +258,18 @@ const GeneralSettings: React.FC = ({ }); const payload = { - router_settings: routerSettings + router_settings: routerSettings, }; try { await setCallbacksCall(accessToken, payload); setRouterSettings({ ...routerSettings }); - setSelectedStrategy(routerSettings["routing_strategy"]) + setSelectedStrategy(routerSettings["routing_strategy"]); message.success("Router settings updated successfully"); } catch (error) { message.error("Failed to update router settings: " + error, 20); } - } + }; const handleSaveChanges = (router_settings: any) => { if (!accessToken) { @@ -223,39 +279,55 @@ const GeneralSettings: React.FC = ({ console.log("router_settings", router_settings); const updatedVariables = Object.fromEntries( - Object.entries(router_settings).map(([key, value]) => { - if (key !== 'routing_strategy_args' && key !== "routing_strategy") { - return [key, (document.querySelector(`input[name="${key}"]`) as HTMLInputElement)?.value || value]; - } - else if (key == "routing_strategy") { - return [key, selectedStrategy] - } - else if (key == "routing_strategy_args" && selectedStrategy == "latency-based-routing") { - let setRoutingStrategyArgs: routingStrategyArgs = {} + Object.entries(router_settings) + .map(([key, value]) => { + if (key !== "routing_strategy_args" && key !== "routing_strategy") { + return [ + key, + ( + document.querySelector( + `input[name="${key}"]` + ) as HTMLInputElement + )?.value || value, + ]; + } else if (key == "routing_strategy") { + return [key, selectedStrategy]; + } else if ( + key == "routing_strategy_args" && + selectedStrategy == "latency-based-routing" + ) { + let setRoutingStrategyArgs: routingStrategyArgs = {}; - const lowestLatencyBufferElement = document.querySelector(`input[name="lowest_latency_buffer"]`) as HTMLInputElement; - const ttlElement = document.querySelector(`input[name="ttl"]`) as HTMLInputElement; + const lowestLatencyBufferElement = document.querySelector( + `input[name="lowest_latency_buffer"]` + ) as HTMLInputElement; + const ttlElement = document.querySelector( + `input[name="ttl"]` + ) as HTMLInputElement; - if (lowestLatencyBufferElement?.value) { - setRoutingStrategyArgs["lowest_latency_buffer"] = Number(lowestLatencyBufferElement.value) + if (lowestLatencyBufferElement?.value) { + setRoutingStrategyArgs["lowest_latency_buffer"] = Number( + lowestLatencyBufferElement.value + ); + } + + if (ttlElement?.value) { + setRoutingStrategyArgs["ttl"] = Number(ttlElement.value); + } + + console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`); + return ["routing_strategy_args", setRoutingStrategyArgs]; } - - if (ttlElement?.value) { - setRoutingStrategyArgs["ttl"] = Number(ttlElement.value) - } - - console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`) - return [ - "routing_strategy_args", setRoutingStrategyArgs - ] - } - return null; - }).filter(entry => entry !== null && entry !== undefined) as Iterable<[string, unknown]> + return null; + }) + .filter((entry) => entry !== null && entry !== undefined) as Iterable< + [string, unknown] + > ); console.log("updatedVariables", updatedVariables); const payload = { - router_settings: updatedVariables + router_settings: updatedVariables, }; try { @@ -267,117 +339,166 @@ const GeneralSettings: React.FC = ({ message.success("router settings updated successfully"); }; - - if (!accessToken) { return null; } - return ( - General Settings + Loadbalancing Fallbacks + General - - Router Settings - - - - - Setting - Value - - - - {Object.entries(routerSettings).filter(([param, value]) => param != "fallbacks" && param != "context_window_fallbacks" && param != "routing_strategy_args").map(([param, value]) => ( - - - {param} - {paramExplanation[param]} - - - { - param == "routing_strategy" ? - - usage-based-routing - latency-based-routing - simple-shuffle - : - - } - + + Router Settings + + + + + Setting + Value + + + + {Object.entries(routerSettings) + .filter( + ([param, value]) => + param != "fallbacks" && + param != "context_window_fallbacks" && + param != "routing_strategy_args" + ) + .map(([param, value]) => ( + + + {param} + + {paramExplanation[param]} + + + + {param == "routing_strategy" ? ( + + + usage-based-routing + + + latency-based-routing + + + simple-shuffle + + + ) : ( + + )} + + + ))} + + + 0 + ? routerSettings["routing_strategy_args"] + : defaultLowestLatencyArgs // default value when keys length is 0 + } + paramExplanation={paramExplanation} + /> + + + handleSaveChanges(routerSettings)} + > + Save Changes + + + + + + + + + Model Name + Fallbacks - ))} - - - 0 - ? routerSettings['routing_strategy_args'] - : defaultLowestLatencyArgs // default value when keys length is 0 - } - paramExplanation={paramExplanation} - /> - - - handleSaveChanges(routerSettings)}> - Save Changes - - - - - - - - - Model Name - Fallbacks - - + - - { - routerSettings["fallbacks"] && - routerSettings["fallbacks"].map((item: Object, index: number) => - Object.entries(item).map(([key, value]) => ( - - {key} - {Array.isArray(value) ? value.join(', ') : value} - - testFallbackModelResponse(key, accessToken)}> - Test Fallback - - - - deleteFallbacks(key)} - /> - - - )) - ) - } - - - data.model_name) : []} accessToken={accessToken} routerSettings={routerSettings} setRouterSettings={setRouterSettings}/> - - - + + {routerSettings["fallbacks"] && + routerSettings["fallbacks"].map( + (item: Object, index: number) => + Object.entries(item).map(([key, value]) => ( + + {key} + + {Array.isArray(value) ? value.join(", ") : value} + + + + testFallbackModelResponse(key, accessToken) + } + > + Test Fallback + + + + deleteFallbacks(key)} + /> + + + )) + )} + + + data.model_name) + : [] + } + accessToken={accessToken} + routerSettings={routerSettings} + setRouterSettings={setRouterSettings} + /> + + + General settings for litellm proxy + + + ); }; -export default GeneralSettings; \ No newline at end of file +export default GeneralSettings;
{paramExplanation[param]}
+ {paramExplanation[param]} +