From 6d8138875f3c476a06c8b8bb215dbe65ff8766e3 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 17 Feb 2025 17:58:29 -0800 Subject: [PATCH] (UI) Refactor Add Models for Specific Teams (#8592) * ui - use common team dropdown component * re-use team component * rename org field on add model * handle add model submit * working view model_id and team_id on root models page * cleaner * show all fields * working model info view * working team info selector * clean up team id * new component for model dashboard * ui show table with dropdown * make public model names like email * revert changes to litellm model name * fix litellm model name * ui fix public model * fix mappings * fix conditional text input * fix message * ui fix bulk add models --- .../components/add_model/add_model_tab.tsx | 138 ++++++++++++++++++ .../conditional_public_model_name.tsx | 69 ++++++--- .../add_model/handle_add_model_submit.tsx | 42 +++--- .../add_model/litellm_model_name.tsx | 28 +++- .../src/components/model_dashboard.tsx | 105 ++----------- .../src/components/networking.tsx | 13 +- 6 files changed, 246 insertions(+), 149 deletions(-) create mode 100644 ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx diff --git a/ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx b/ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx new file mode 100644 index 0000000000..389ad0d4e2 --- /dev/null +++ b/ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx @@ -0,0 +1,138 @@ +import React from "react"; +import { Card, Form, Button, Tooltip, Typography, Select as AntdSelect } from "antd"; +import type { FormInstance } from "antd"; +import type { UploadProps } from "antd/es/upload"; +import LiteLLMModelNameField from "./litellm_model_name"; +import ConditionalPublicModelName from "./conditional_public_model_name"; +import ProviderSpecificFields from "./provider_specific_fields"; +import AdvancedSettings from "./advanced_settings"; +import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers"; +import type { Team } from "../key_team_helpers/key_list"; + +interface AddModelTabProps { + form: FormInstance; + handleOk: () => void; + selectedProvider: Providers; + setSelectedProvider: (provider: Providers) => void; + providerModels: string[]; + setProviderModelsFn: (provider: Providers) => void; + getPlaceholder: (provider: Providers) => string; + uploadProps: UploadProps; + showAdvancedSettings: boolean; + setShowAdvancedSettings: (show: boolean) => void; + teams: Team[] | null; +} + +const { Title, Link } = Typography; + +const AddModelTab: React.FC = ({ + form, + handleOk, + selectedProvider, + setSelectedProvider, + providerModels, + setProviderModelsFn, + getPlaceholder, + uploadProps, + showAdvancedSettings, + setShowAdvancedSettings, + teams, +}) => { + return ( + <> + Add new model + +
+ <> + {/* Provider Selection */} + + { + setSelectedProvider(value); + setProviderModelsFn(value); + form.setFieldsValue({ + model: [], + model_name: undefined + }); + }} + > + {Object.entries(Providers).map(([providerEnum, providerDisplayName]) => ( + +
+ {`${providerEnum} { + // Create a div with provider initial as fallback + const target = e.target as HTMLImageElement; + const parent = target.parentElement; + if (parent) { + const fallbackDiv = document.createElement('div'); + fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs'; + fallbackDiv.textContent = providerDisplayName.charAt(0); + parent.replaceChild(fallbackDiv, target); + } + }} + /> + {providerDisplayName} +
+
+ ))} +
+
+ + + {/* Conditionally Render "Public Model Name" */} + + + + + + +
+ + + Need Help? + + + +
+ + +
+ + + + ); +}; + +export default AddModelTab; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/add_model/conditional_public_model_name.tsx b/ui/litellm-dashboard/src/components/add_model/conditional_public_model_name.tsx index 85fd870aff..264968ceea 100644 --- a/ui/litellm-dashboard/src/components/add_model/conditional_public_model_name.tsx +++ b/ui/litellm-dashboard/src/components/add_model/conditional_public_model_name.tsx @@ -1,6 +1,6 @@ -import React from "react"; -import { Form } from "antd"; -import { TextInput, Text } from "@tremor/react"; +import React, { useEffect } from "react"; +import { Form, Table, Input } from "antd"; +import { Text, TextInput } from "@tremor/react"; import { Row, Col } from "antd"; const ConditionalPublicModelName: React.FC = () => { @@ -11,32 +11,61 @@ const ConditionalPublicModelName: React.FC = () => { const selectedModels = Form.useWatch('model', form) || []; const showPublicModelName = !selectedModels.includes('all-wildcard'); + // Auto-populate model mappings when selected models change + useEffect(() => { + if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) { + const mappings = selectedModels.map(model => ({ + public_name: model, + litellm_model: model + })); + form.setFieldValue('model_mappings', mappings); + } + }, [selectedModels, form]); + if (!showPublicModelName) return null; + const columns = [ + { + title: 'Public Name', + dataIndex: 'public_name', + key: 'public_name', + render: (text: string, record: any, index: number) => { + return ( + { + const newMappings = [...form.getFieldValue('model_mappings')]; + newMappings[index].public_name = e.target.value; + form.setFieldValue('model_mappings', newMappings); + }} + /> + ); + } + }, + { + title: 'LiteLLM Model', + dataIndex: 'litellm_model', + key: 'litellm_model', + } + ]; + return ( <> ({ - validator(_, value) { - const selectedModels = getFieldValue('model') || []; - if (!selectedModels.includes('all-wildcard') || value) { - return Promise.resolve(); - } - return Promise.reject(new Error('Public Model Name is required unless "All Models" is selected.')); - }, - }), - ]} + required={true} > - + diff --git a/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx b/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx index 4f8a3ae607..00035652ee 100644 --- a/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx +++ b/ui/litellm-dashboard/src/components/add_model/handle_add_model_submit.tsx @@ -11,7 +11,8 @@ export const handleAddModelSubmit = async ( ) => { try { console.log("handling submit for formValues:", formValues); - // If model_name is not provided, use provider.toLowerCase() + "/*" + + // Handle wildcard case if (formValues["model"] && formValues["model"].includes("all-wildcard")) { const customProvider: Providers = formValues["custom_llm_provider"]; const litellm_custom_provider = provider_map[customProvider as keyof typeof Providers]; @@ -19,26 +20,19 @@ export const handleAddModelSubmit = async ( formValues["model_name"] = wildcardModel; formValues["model"] = wildcardModel; } - /** - * For multiple litellm model names - create a separate deployment for each - * - get the list - * - iterate through it - * - create a new deployment for each - * - * For single model name -> make it a 1 item list - */ - - // get the list of deployments - let deployments: Array = Array.isArray(formValues["model"]) - ? formValues["model"] - : [formValues["model"]]; - console.log(`received deployments: ${deployments}`); - console.log(`received type of deployments: ${typeof deployments}`); - deployments.forEach(async (litellm_model) => { - console.log(`litellm_model: ${litellm_model}`); + + // Get model mappings + const modelMappings = formValues["model_mappings"] || []; + + // Create a deployment for each mapping + for (const mapping of modelMappings) { const litellmParamsObj: Record = {}; const modelInfoObj: Record = {}; + // Set the model name and litellm model from the mapping + const modelName = mapping.public_name; + litellmParamsObj["model"] = mapping.litellm_model; + // Handle pricing conversion before processing other fields if (formValues.input_cost_per_token) { formValues.input_cost_per_token = Number(formValues.input_cost_per_token) / 1000000; @@ -49,8 +43,7 @@ export const handleAddModelSubmit = async ( // Keep input_cost_per_second as is, no conversion needed // Iterate through the key-value pairs in formValues - litellmParamsObj["model"] = litellm_model; - let modelName: string = ""; + litellmParamsObj["model"] = mapping.litellm_model; console.log("formValues add deployment:", formValues); for (const [key, value] of Object.entries(formValues)) { if (value === "") { @@ -61,7 +54,7 @@ export const handleAddModelSubmit = async ( continue; } if (key == "model_name") { - modelName = modelName + value; + litellmParamsObj["model"] = value; } else if (key == "custom_llm_provider") { console.log("custom_llm_provider:", value); const mappingResult = provider_map[value]; // Get the corresponding value from the mapping @@ -141,11 +134,10 @@ export const handleAddModelSubmit = async ( }; const response: any = await modelCreateCall(accessToken, new_model); - callback && callback() - console.log(`response for model create call: ${response["data"]}`); - }); - + } + + callback && callback() form.resetFields(); } catch (error) { message.error("Failed to create model: " + error, 10); diff --git a/ui/litellm-dashboard/src/components/add_model/litellm_model_name.tsx b/ui/litellm-dashboard/src/components/add_model/litellm_model_name.tsx index 48a8bf816f..883923316b 100644 --- a/ui/litellm-dashboard/src/components/add_model/litellm_model_name.tsx +++ b/ui/litellm-dashboard/src/components/add_model/litellm_model_name.tsx @@ -5,22 +5,33 @@ import { Row, Col } from "antd"; import { Providers } from "../provider_info_helpers"; interface LiteLLMModelNameFieldProps { - selectedProvider: string; + selectedProvider: Providers; providerModels: string[]; - getPlaceholder: (provider: string) => string; + getPlaceholder: (provider: Providers) => string; } const LiteLLMModelNameField: React.FC = ({ selectedProvider, - providerModels, + providerModels, getPlaceholder, }) => { const form = Form.useFormInstance(); - const handleModelChange = (value: string[]) => { + const handleModelChange = (value: string | string[]) => { + // Ensure value is always treated as an array + const values = Array.isArray(value) ? value : [value]; + // If "all-wildcard" is selected, clear the model_name field - if (value.includes("all-wildcard")) { - form.setFieldsValue({ model_name: undefined }); + if (values.includes("all-wildcard")) { + form.setFieldsValue({ model_name: undefined, model_mappings: [] }); + } else { + // Update model mappings immediately for each selected model + const mappings = values + .map(model => ({ + public_name: model, + litellm_model: model + })); + form.setFieldsValue({ model_mappings: mappings }); } }; @@ -39,9 +50,10 @@ const LiteLLMModelNameField: React.FC = ({ {(selectedProvider === Providers.Azure) || (selectedProvider === Providers.OpenAI_Compatible) || (selectedProvider === Providers.Ollama) ? ( - + ) : providerModels.length > 0 ? ( = ({ style={{ width: '100%' }} /> ) : ( - + )} diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx index 96be0e640d..b8074f2b81 100644 --- a/ui/litellm-dashboard/src/components/model_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx @@ -104,6 +104,7 @@ import { Team } from "./key_team_helpers/key_list"; import TeamInfoView from "./team/team_info"; import { Providers, provider_map, providerLogoMap, getProviderLogoAndName, getPlaceholder, getProviderModels } from "./provider_info_helpers"; import ModelInfoView from "./model_info_view"; +import AddModelTab from "./add_model/add_model_tab"; interface ModelDashboardProps { accessToken: string | null; @@ -1046,7 +1047,6 @@ const ModelDashboard: React.FC = ({ .validateFields() .then((values) => { handleAddModelSubmit(values, accessToken, form, handleRefreshClick); - // form.resetFields(); }) .catch((error) => { console.error("Validation failed:", error); @@ -1582,96 +1582,19 @@ const ModelDashboard: React.FC = ({ /> - Add new model - -
- <> - {/* Provider Selection */} - - { - setSelectedProvider(value); - setProviderModelsFn(value); - form.setFieldsValue({ - model: [], - model_name: undefined - }); - }} - > - {Object.entries(Providers).map(([providerEnum, providerDisplayName]) => ( - -
- {`${providerEnum} { - // Create a div with provider initial as fallback - const target = e.target as HTMLImageElement; - const parent = target.parentElement; - if (parent) { - const fallbackDiv = document.createElement('div'); - fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs'; - fallbackDiv.textContent = providerDisplayName.charAt(0); - parent.replaceChild(fallbackDiv, target); - } - }} - /> - {providerDisplayName} -
-
- ))} -
-
- - - {/* Conditionally Render "Public Model Name" */} - - - - - - -
- - - Need Help? - - - Add Model -
- - -
+
diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a59411fa6d..96bf0222e9 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -92,7 +92,6 @@ export const modelCostMap = async ( throw error; } }; - export const modelCreateCall = async ( accessToken: string, formValues: Model @@ -106,7 +105,7 @@ export const modelCreateCall = async ( "Content-Type": "application/json", }, body: JSON.stringify({ - ...formValues, // Include formValues in the request body + ...formValues, }), }); @@ -121,9 +120,13 @@ export const modelCreateCall = async ( const data = await response.json(); console.log("API Response:", data); - message.success( - "Model created successfully" - ); + + // Close any existing messages before showing new ones + message.destroy(); + + // Sequential success messages + message.success(`Model ${formValues.model_name} created successfully`, 2); + return data; } catch (error) { console.error("Failed to create key:", error);