mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(UI) Refactor Add Models for Specific Teams (#8592)
* ui - use common team dropdown component * re-use team component * rename org field on add model * handle add model submit * working view model_id and team_id on root models page * cleaner * show all fields * working model info view * working team info selector * clean up team id * new component for model dashboard * ui show table with dropdown * make public model names like email * revert changes to litellm model name * fix litellm model name * ui fix public model * fix mappings * fix conditional text input * fix message * ui fix bulk add models
This commit is contained in:
parent
e5c7a9ea08
commit
6d8138875f
6 changed files with 246 additions and 149 deletions
138
ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx
Normal file
138
ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx
Normal file
|
@ -0,0 +1,138 @@
|
|||
import React from "react";
|
||||
import { Card, Form, Button, Tooltip, Typography, Select as AntdSelect } from "antd";
|
||||
import type { FormInstance } from "antd";
|
||||
import type { UploadProps } from "antd/es/upload";
|
||||
import LiteLLMModelNameField from "./litellm_model_name";
|
||||
import ConditionalPublicModelName from "./conditional_public_model_name";
|
||||
import ProviderSpecificFields from "./provider_specific_fields";
|
||||
import AdvancedSettings from "./advanced_settings";
|
||||
import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers";
|
||||
import type { Team } from "../key_team_helpers/key_list";
|
||||
|
||||
interface AddModelTabProps {
|
||||
form: FormInstance;
|
||||
handleOk: () => void;
|
||||
selectedProvider: Providers;
|
||||
setSelectedProvider: (provider: Providers) => void;
|
||||
providerModels: string[];
|
||||
setProviderModelsFn: (provider: Providers) => void;
|
||||
getPlaceholder: (provider: Providers) => string;
|
||||
uploadProps: UploadProps;
|
||||
showAdvancedSettings: boolean;
|
||||
setShowAdvancedSettings: (show: boolean) => void;
|
||||
teams: Team[] | null;
|
||||
}
|
||||
|
||||
const { Title, Link } = Typography;
|
||||
|
||||
const AddModelTab: React.FC<AddModelTabProps> = ({
|
||||
form,
|
||||
handleOk,
|
||||
selectedProvider,
|
||||
setSelectedProvider,
|
||||
providerModels,
|
||||
setProviderModelsFn,
|
||||
getPlaceholder,
|
||||
uploadProps,
|
||||
showAdvancedSettings,
|
||||
setShowAdvancedSettings,
|
||||
teams,
|
||||
}) => {
|
||||
return (
|
||||
<>
|
||||
<Title level={2}>Add new model</Title>
|
||||
<Card>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleOk}
|
||||
labelCol={{ span: 10 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<>
|
||||
{/* Provider Selection */}
|
||||
<Form.Item
|
||||
rules={[{ required: true, message: "Required" }]}
|
||||
label="Provider:"
|
||||
name="custom_llm_provider"
|
||||
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
|
||||
labelCol={{ span: 10 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<AntdSelect
|
||||
showSearch={true}
|
||||
value={selectedProvider}
|
||||
onChange={(value) => {
|
||||
setSelectedProvider(value);
|
||||
setProviderModelsFn(value);
|
||||
form.setFieldsValue({
|
||||
model: [],
|
||||
model_name: undefined
|
||||
});
|
||||
}}
|
||||
>
|
||||
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
|
||||
<AntdSelect.Option
|
||||
key={providerEnum}
|
||||
value={providerEnum}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
<img
|
||||
src={providerLogoMap[providerDisplayName]}
|
||||
alt={`${providerEnum} logo`}
|
||||
className="w-5 h-5"
|
||||
onError={(e) => {
|
||||
// Create a div with provider initial as fallback
|
||||
const target = e.target as HTMLImageElement;
|
||||
const parent = target.parentElement;
|
||||
if (parent) {
|
||||
const fallbackDiv = document.createElement('div');
|
||||
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
|
||||
fallbackDiv.textContent = providerDisplayName.charAt(0);
|
||||
parent.replaceChild(fallbackDiv, target);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<span>{providerDisplayName}</span>
|
||||
</div>
|
||||
</AntdSelect.Option>
|
||||
))}
|
||||
</AntdSelect>
|
||||
</Form.Item>
|
||||
<LiteLLMModelNameField
|
||||
selectedProvider={selectedProvider}
|
||||
providerModels={providerModels}
|
||||
getPlaceholder={getPlaceholder}
|
||||
/>
|
||||
|
||||
{/* Conditionally Render "Public Model Name" */}
|
||||
<ConditionalPublicModelName />
|
||||
|
||||
<ProviderSpecificFields
|
||||
selectedProvider={selectedProvider}
|
||||
uploadProps={uploadProps}
|
||||
/>
|
||||
<AdvancedSettings
|
||||
showAdvancedSettings={showAdvancedSettings}
|
||||
setShowAdvancedSettings={setShowAdvancedSettings}
|
||||
/>
|
||||
|
||||
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<Tooltip title="Get help on our github">
|
||||
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
|
||||
Need Help?
|
||||
</Typography.Link>
|
||||
</Tooltip>
|
||||
<Button htmlType="submit">Add Model</Button>
|
||||
</div>
|
||||
</>
|
||||
</Form>
|
||||
</Card>
|
||||
|
||||
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default AddModelTab;
|
|
@ -1,6 +1,6 @@
|
|||
import React from "react";
|
||||
import { Form } from "antd";
|
||||
import { TextInput, Text } from "@tremor/react";
|
||||
import React, { useEffect } from "react";
|
||||
import { Form, Table, Input } from "antd";
|
||||
import { Text, TextInput } from "@tremor/react";
|
||||
import { Row, Col } from "antd";
|
||||
|
||||
const ConditionalPublicModelName: React.FC = () => {
|
||||
|
@ -11,32 +11,61 @@ const ConditionalPublicModelName: React.FC = () => {
|
|||
const selectedModels = Form.useWatch('model', form) || [];
|
||||
const showPublicModelName = !selectedModels.includes('all-wildcard');
|
||||
|
||||
// Auto-populate model mappings when selected models change
|
||||
useEffect(() => {
|
||||
if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
|
||||
const mappings = selectedModels.map(model => ({
|
||||
public_name: model,
|
||||
litellm_model: model
|
||||
}));
|
||||
form.setFieldValue('model_mappings', mappings);
|
||||
}
|
||||
}, [selectedModels, form]);
|
||||
|
||||
if (!showPublicModelName) return null;
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: 'Public Name',
|
||||
dataIndex: 'public_name',
|
||||
key: 'public_name',
|
||||
render: (text: string, record: any, index: number) => {
|
||||
return (
|
||||
<TextInput
|
||||
defaultValue={text}
|
||||
onChange={(e) => {
|
||||
const newMappings = [...form.getFieldValue('model_mappings')];
|
||||
newMappings[index].public_name = e.target.value;
|
||||
form.setFieldValue('model_mappings', newMappings);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: 'LiteLLM Model',
|
||||
dataIndex: 'litellm_model',
|
||||
key: 'litellm_model',
|
||||
}
|
||||
];
|
||||
|
||||
return (
|
||||
<>
|
||||
<Form.Item
|
||||
label="Public Model Name"
|
||||
name="model_name"
|
||||
tooltip="Model name your users will pass in. Also used for load-balancing, LiteLLM will load balance between all models with this public name."
|
||||
label="Model Mappings"
|
||||
name="model_mappings"
|
||||
tooltip="Map public model names to LiteLLM model names for load balancing"
|
||||
labelCol={{ span: 10 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
required={false}
|
||||
className="mb-0"
|
||||
rules={[
|
||||
({ getFieldValue }) => ({
|
||||
validator(_, value) {
|
||||
const selectedModels = getFieldValue('model') || [];
|
||||
if (!selectedModels.includes('all-wildcard') || value) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
return Promise.reject(new Error('Public Model Name is required unless "All Models" is selected.'));
|
||||
},
|
||||
}),
|
||||
]}
|
||||
required={true}
|
||||
>
|
||||
<TextInput placeholder="my-gpt-4" />
|
||||
<Table
|
||||
dataSource={form.getFieldValue('model_mappings')}
|
||||
columns={columns}
|
||||
pagination={false}
|
||||
size="small"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Row>
|
||||
<Col span={10}></Col>
|
||||
|
|
|
@ -11,7 +11,8 @@ export const handleAddModelSubmit = async (
|
|||
) => {
|
||||
try {
|
||||
console.log("handling submit for formValues:", formValues);
|
||||
// If model_name is not provided, use provider.toLowerCase() + "/*"
|
||||
|
||||
// Handle wildcard case
|
||||
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
|
||||
const customProvider: Providers = formValues["custom_llm_provider"];
|
||||
const litellm_custom_provider = provider_map[customProvider as keyof typeof Providers];
|
||||
|
@ -19,26 +20,19 @@ export const handleAddModelSubmit = async (
|
|||
formValues["model_name"] = wildcardModel;
|
||||
formValues["model"] = wildcardModel;
|
||||
}
|
||||
/**
|
||||
* For multiple litellm model names - create a separate deployment for each
|
||||
* - get the list
|
||||
* - iterate through it
|
||||
* - create a new deployment for each
|
||||
*
|
||||
* For single model name -> make it a 1 item list
|
||||
*/
|
||||
|
||||
// get the list of deployments
|
||||
let deployments: Array<string> = Array.isArray(formValues["model"])
|
||||
? formValues["model"]
|
||||
: [formValues["model"]];
|
||||
console.log(`received deployments: ${deployments}`);
|
||||
console.log(`received type of deployments: ${typeof deployments}`);
|
||||
deployments.forEach(async (litellm_model) => {
|
||||
console.log(`litellm_model: ${litellm_model}`);
|
||||
|
||||
// Get model mappings
|
||||
const modelMappings = formValues["model_mappings"] || [];
|
||||
|
||||
// Create a deployment for each mapping
|
||||
for (const mapping of modelMappings) {
|
||||
const litellmParamsObj: Record<string, any> = {};
|
||||
const modelInfoObj: Record<string, any> = {};
|
||||
|
||||
// Set the model name and litellm model from the mapping
|
||||
const modelName = mapping.public_name;
|
||||
litellmParamsObj["model"] = mapping.litellm_model;
|
||||
|
||||
// Handle pricing conversion before processing other fields
|
||||
if (formValues.input_cost_per_token) {
|
||||
formValues.input_cost_per_token = Number(formValues.input_cost_per_token) / 1000000;
|
||||
|
@ -49,8 +43,7 @@ export const handleAddModelSubmit = async (
|
|||
// Keep input_cost_per_second as is, no conversion needed
|
||||
|
||||
// Iterate through the key-value pairs in formValues
|
||||
litellmParamsObj["model"] = litellm_model;
|
||||
let modelName: string = "";
|
||||
litellmParamsObj["model"] = mapping.litellm_model;
|
||||
console.log("formValues add deployment:", formValues);
|
||||
for (const [key, value] of Object.entries(formValues)) {
|
||||
if (value === "") {
|
||||
|
@ -61,7 +54,7 @@ export const handleAddModelSubmit = async (
|
|||
continue;
|
||||
}
|
||||
if (key == "model_name") {
|
||||
modelName = modelName + value;
|
||||
litellmParamsObj["model"] = value;
|
||||
} else if (key == "custom_llm_provider") {
|
||||
console.log("custom_llm_provider:", value);
|
||||
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
|
||||
|
@ -141,11 +134,10 @@ export const handleAddModelSubmit = async (
|
|||
};
|
||||
|
||||
const response: any = await modelCreateCall(accessToken, new_model);
|
||||
callback && callback()
|
||||
|
||||
console.log(`response for model create call: ${response["data"]}`);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
callback && callback()
|
||||
form.resetFields();
|
||||
} catch (error) {
|
||||
message.error("Failed to create model: " + error, 10);
|
||||
|
|
|
@ -5,22 +5,33 @@ import { Row, Col } from "antd";
|
|||
import { Providers } from "../provider_info_helpers";
|
||||
|
||||
interface LiteLLMModelNameFieldProps {
|
||||
selectedProvider: string;
|
||||
selectedProvider: Providers;
|
||||
providerModels: string[];
|
||||
getPlaceholder: (provider: string) => string;
|
||||
getPlaceholder: (provider: Providers) => string;
|
||||
}
|
||||
|
||||
const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
|
||||
selectedProvider,
|
||||
providerModels,
|
||||
providerModels,
|
||||
getPlaceholder,
|
||||
}) => {
|
||||
const form = Form.useFormInstance();
|
||||
|
||||
const handleModelChange = (value: string[]) => {
|
||||
const handleModelChange = (value: string | string[]) => {
|
||||
// Ensure value is always treated as an array
|
||||
const values = Array.isArray(value) ? value : [value];
|
||||
|
||||
// If "all-wildcard" is selected, clear the model_name field
|
||||
if (value.includes("all-wildcard")) {
|
||||
form.setFieldsValue({ model_name: undefined });
|
||||
if (values.includes("all-wildcard")) {
|
||||
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
|
||||
} else {
|
||||
// Update model mappings immediately for each selected model
|
||||
const mappings = values
|
||||
.map(model => ({
|
||||
public_name: model,
|
||||
litellm_model: model
|
||||
}));
|
||||
form.setFieldsValue({ model_mappings: mappings });
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -39,9 +50,10 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
|
|||
{(selectedProvider === Providers.Azure) ||
|
||||
(selectedProvider === Providers.OpenAI_Compatible) ||
|
||||
(selectedProvider === Providers.Ollama) ? (
|
||||
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
|
||||
<TextInput placeholder={getPlaceholder(selectedProvider)} />
|
||||
) : providerModels.length > 0 ? (
|
||||
<AntSelect
|
||||
mode="multiple"
|
||||
allowClear
|
||||
showSearch
|
||||
placeholder="Select models"
|
||||
|
@ -67,7 +79,7 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
|
|||
style={{ width: '100%' }}
|
||||
/>
|
||||
) : (
|
||||
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
|
||||
<TextInput placeholder={getPlaceholder(selectedProvider)} />
|
||||
)}
|
||||
</Form.Item>
|
||||
|
||||
|
|
|
@ -104,6 +104,7 @@ import { Team } from "./key_team_helpers/key_list";
|
|||
import TeamInfoView from "./team/team_info";
|
||||
import { Providers, provider_map, providerLogoMap, getProviderLogoAndName, getPlaceholder, getProviderModels } from "./provider_info_helpers";
|
||||
import ModelInfoView from "./model_info_view";
|
||||
import AddModelTab from "./add_model/add_model_tab";
|
||||
|
||||
interface ModelDashboardProps {
|
||||
accessToken: string | null;
|
||||
|
@ -1046,7 +1047,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
.validateFields()
|
||||
.then((values) => {
|
||||
handleAddModelSubmit(values, accessToken, form, handleRefreshClick);
|
||||
// form.resetFields();
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error("Validation failed:", error);
|
||||
|
@ -1582,96 +1582,19 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
/>
|
||||
</TabPanel>
|
||||
<TabPanel className="h-full">
|
||||
<Title2 level={2}>Add new model</Title2>
|
||||
<Card>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleOk}
|
||||
labelCol={{ span: 10 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<>
|
||||
{/* Provider Selection */}
|
||||
<Form.Item
|
||||
rules={[{ required: true, message: "Required" }]}
|
||||
label="Provider:"
|
||||
name="custom_llm_provider"
|
||||
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
|
||||
labelCol={{ span: 10 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<AntdSelect
|
||||
showSearch={true}
|
||||
value={selectedProvider}
|
||||
onChange={(value) => {
|
||||
setSelectedProvider(value);
|
||||
setProviderModelsFn(value);
|
||||
form.setFieldsValue({
|
||||
model: [],
|
||||
model_name: undefined
|
||||
});
|
||||
}}
|
||||
>
|
||||
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
|
||||
<AntdSelect.Option
|
||||
key={providerEnum}
|
||||
value={providerEnum}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
<img
|
||||
src={providerLogoMap[providerDisplayName]}
|
||||
alt={`${providerEnum} logo`}
|
||||
className="w-5 h-5"
|
||||
onError={(e) => {
|
||||
// Create a div with provider initial as fallback
|
||||
const target = e.target as HTMLImageElement;
|
||||
const parent = target.parentElement;
|
||||
if (parent) {
|
||||
const fallbackDiv = document.createElement('div');
|
||||
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
|
||||
fallbackDiv.textContent = providerDisplayName.charAt(0);
|
||||
parent.replaceChild(fallbackDiv, target);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<span>{providerDisplayName}</span>
|
||||
</div>
|
||||
</AntdSelect.Option>
|
||||
))}
|
||||
</AntdSelect>
|
||||
</Form.Item>
|
||||
<LiteLLMModelNameField
|
||||
selectedProvider={selectedProvider}
|
||||
providerModels={providerModels}
|
||||
getPlaceholder={getPlaceholder}
|
||||
/>
|
||||
|
||||
{/* Conditionally Render "Public Model Name" */}
|
||||
<ConditionalPublicModelName />
|
||||
|
||||
<ProviderSpecificFields
|
||||
selectedProvider={selectedProvider}
|
||||
uploadProps={uploadProps}
|
||||
/>
|
||||
<AdvancedSettings
|
||||
showAdvancedSettings={showAdvancedSettings}
|
||||
setShowAdvancedSettings={setShowAdvancedSettings}
|
||||
teams={teams}
|
||||
/>
|
||||
|
||||
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<Tooltip title="Get help on our github">
|
||||
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
|
||||
Need Help?
|
||||
</Typography.Link>
|
||||
</Tooltip>
|
||||
<Button2 htmlType="submit">Add Model</Button2>
|
||||
</div>
|
||||
</>
|
||||
</Form>
|
||||
</Card>
|
||||
<AddModelTab
|
||||
form={form}
|
||||
handleOk={handleOk}
|
||||
selectedProvider={selectedProvider}
|
||||
setSelectedProvider={setSelectedProvider}
|
||||
providerModels={providerModels}
|
||||
setProviderModelsFn={setProviderModelsFn}
|
||||
getPlaceholder={getPlaceholder}
|
||||
uploadProps={uploadProps}
|
||||
showAdvancedSettings={showAdvancedSettings}
|
||||
setShowAdvancedSettings={setShowAdvancedSettings}
|
||||
teams={teams}
|
||||
/>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Card>
|
||||
|
|
|
@ -92,7 +92,6 @@ export const modelCostMap = async (
|
|||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const modelCreateCall = async (
|
||||
accessToken: string,
|
||||
formValues: Model
|
||||
|
@ -106,7 +105,7 @@ export const modelCreateCall = async (
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...formValues, // Include formValues in the request body
|
||||
...formValues,
|
||||
}),
|
||||
});
|
||||
|
||||
|
@ -121,9 +120,13 @@ export const modelCreateCall = async (
|
|||
|
||||
const data = await response.json();
|
||||
console.log("API Response:", data);
|
||||
message.success(
|
||||
"Model created successfully"
|
||||
);
|
||||
|
||||
// Close any existing messages before showing new ones
|
||||
message.destroy();
|
||||
|
||||
// Sequential success messages
|
||||
message.success(`Model ${formValues.model_name} created successfully`, 2);
|
||||
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error("Failed to create key:", error);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue