(UI) Refactor Add Models for Specific Teams (#8592)

* ui - use common team dropdown component

* re-use team component

* rename org field on add model

* handle add model submit

* working view model_id and team_id on root models page

* cleaner

* show all fields

* working model info view

* working team info selector

* clean up team id

* new component for model dashboard

* ui show table with dropdown

* make public model names like email

* revert changes to litellm model name

* fix litellm model name

* ui fix public model

* fix mappings

* fix conditional text input

* fix message

* ui fix bulk add models
This commit is contained in:
Ishaan Jaff 2025-02-17 17:58:29 -08:00 committed by GitHub
parent e5c7a9ea08
commit 6d8138875f
6 changed files with 246 additions and 149 deletions

View file

@ -0,0 +1,138 @@
import React from "react";
import { Card, Form, Button, Tooltip, Typography, Select as AntdSelect } from "antd";
import type { FormInstance } from "antd";
import type { UploadProps } from "antd/es/upload";
import LiteLLMModelNameField from "./litellm_model_name";
import ConditionalPublicModelName from "./conditional_public_model_name";
import ProviderSpecificFields from "./provider_specific_fields";
import AdvancedSettings from "./advanced_settings";
import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers";
import type { Team } from "../key_team_helpers/key_list";
interface AddModelTabProps {
form: FormInstance;
handleOk: () => void;
selectedProvider: Providers;
setSelectedProvider: (provider: Providers) => void;
providerModels: string[];
setProviderModelsFn: (provider: Providers) => void;
getPlaceholder: (provider: Providers) => string;
uploadProps: UploadProps;
showAdvancedSettings: boolean;
setShowAdvancedSettings: (show: boolean) => void;
teams: Team[] | null;
}
const { Title, Link } = Typography;
const AddModelTab: React.FC<AddModelTabProps> = ({
form,
handleOk,
selectedProvider,
setSelectedProvider,
providerModels,
setProviderModelsFn,
getPlaceholder,
uploadProps,
showAdvancedSettings,
setShowAdvancedSettings,
teams,
}) => {
return (
<>
<Title level={2}>Add new model</Title>
<Card>
<Form
form={form}
onFinish={handleOk}
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<>
{/* Provider Selection */}
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="Provider:"
name="custom_llm_provider"
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
labelCol={{ span: 10 }}
labelAlign="left"
>
<AntdSelect
showSearch={true}
value={selectedProvider}
onChange={(value) => {
setSelectedProvider(value);
setProviderModelsFn(value);
form.setFieldsValue({
model: [],
model_name: undefined
});
}}
>
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
<AntdSelect.Option
key={providerEnum}
value={providerEnum}
>
<div className="flex items-center space-x-2">
<img
src={providerLogoMap[providerDisplayName]}
alt={`${providerEnum} logo`}
className="w-5 h-5"
onError={(e) => {
// Create a div with provider initial as fallback
const target = e.target as HTMLImageElement;
const parent = target.parentElement;
if (parent) {
const fallbackDiv = document.createElement('div');
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
fallbackDiv.textContent = providerDisplayName.charAt(0);
parent.replaceChild(fallbackDiv, target);
}
}}
/>
<span>{providerDisplayName}</span>
</div>
</AntdSelect.Option>
))}
</AntdSelect>
</Form.Item>
<LiteLLMModelNameField
selectedProvider={selectedProvider}
providerModels={providerModels}
getPlaceholder={getPlaceholder}
/>
{/* Conditionally Render "Public Model Name" */}
<ConditionalPublicModelName />
<ProviderSpecificFields
selectedProvider={selectedProvider}
uploadProps={uploadProps}
/>
<AdvancedSettings
showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings}
/>
<div className="flex justify-between items-center mb-4">
<Tooltip title="Get help on our github">
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
Need Help?
</Typography.Link>
</Tooltip>
<Button htmlType="submit">Add Model</Button>
</div>
</>
</Form>
</Card>
</>
);
};
export default AddModelTab;

View file

@ -1,6 +1,6 @@
import React from "react";
import { Form } from "antd";
import { TextInput, Text } from "@tremor/react";
import React, { useEffect } from "react";
import { Form, Table, Input } from "antd";
import { Text, TextInput } from "@tremor/react";
import { Row, Col } from "antd";
const ConditionalPublicModelName: React.FC = () => {
@ -11,32 +11,61 @@ const ConditionalPublicModelName: React.FC = () => {
const selectedModels = Form.useWatch('model', form) || [];
const showPublicModelName = !selectedModels.includes('all-wildcard');
// Auto-populate model mappings when selected models change
useEffect(() => {
if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
const mappings = selectedModels.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldValue('model_mappings', mappings);
}
}, [selectedModels, form]);
if (!showPublicModelName) return null;
const columns = [
{
title: 'Public Name',
dataIndex: 'public_name',
key: 'public_name',
render: (text: string, record: any, index: number) => {
return (
<TextInput
defaultValue={text}
onChange={(e) => {
const newMappings = [...form.getFieldValue('model_mappings')];
newMappings[index].public_name = e.target.value;
form.setFieldValue('model_mappings', newMappings);
}}
/>
);
}
},
{
title: 'LiteLLM Model',
dataIndex: 'litellm_model',
key: 'litellm_model',
}
];
return (
<>
<Form.Item
label="Public Model Name"
name="model_name"
tooltip="Model name your users will pass in. Also used for load-balancing, LiteLLM will load balance between all models with this public name."
label="Model Mappings"
name="model_mappings"
tooltip="Map public model names to LiteLLM model names for load balancing"
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
required={false}
className="mb-0"
rules={[
({ getFieldValue }) => ({
validator(_, value) {
const selectedModels = getFieldValue('model') || [];
if (!selectedModels.includes('all-wildcard') || value) {
return Promise.resolve();
}
return Promise.reject(new Error('Public Model Name is required unless "All Models" is selected.'));
},
}),
]}
required={true}
>
<TextInput placeholder="my-gpt-4" />
<Table
dataSource={form.getFieldValue('model_mappings')}
columns={columns}
pagination={false}
size="small"
/>
</Form.Item>
<Row>
<Col span={10}></Col>

View file

@ -11,7 +11,8 @@ export const handleAddModelSubmit = async (
) => {
try {
console.log("handling submit for formValues:", formValues);
// If model_name is not provided, use provider.toLowerCase() + "/*"
// Handle wildcard case
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
const customProvider: Providers = formValues["custom_llm_provider"];
const litellm_custom_provider = provider_map[customProvider as keyof typeof Providers];
@ -19,26 +20,19 @@ export const handleAddModelSubmit = async (
formValues["model_name"] = wildcardModel;
formValues["model"] = wildcardModel;
}
/**
* For multiple litellm model names - create a separate deployment for each
* - get the list
* - iterate through it
* - create a new deployment for each
*
* For single model name -> make it a 1 item list
*/
// get the list of deployments
let deployments: Array<string> = Array.isArray(formValues["model"])
? formValues["model"]
: [formValues["model"]];
console.log(`received deployments: ${deployments}`);
console.log(`received type of deployments: ${typeof deployments}`);
deployments.forEach(async (litellm_model) => {
console.log(`litellm_model: ${litellm_model}`);
// Get model mappings
const modelMappings = formValues["model_mappings"] || [];
// Create a deployment for each mapping
for (const mapping of modelMappings) {
const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {};
// Set the model name and litellm model from the mapping
const modelName = mapping.public_name;
litellmParamsObj["model"] = mapping.litellm_model;
// Handle pricing conversion before processing other fields
if (formValues.input_cost_per_token) {
formValues.input_cost_per_token = Number(formValues.input_cost_per_token) / 1000000;
@ -49,8 +43,7 @@ export const handleAddModelSubmit = async (
// Keep input_cost_per_second as is, no conversion needed
// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model;
let modelName: string = "";
litellmParamsObj["model"] = mapping.litellm_model;
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === "") {
@ -61,7 +54,7 @@ export const handleAddModelSubmit = async (
continue;
}
if (key == "model_name") {
modelName = modelName + value;
litellmParamsObj["model"] = value;
} else if (key == "custom_llm_provider") {
console.log("custom_llm_provider:", value);
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
@ -141,11 +134,10 @@ export const handleAddModelSubmit = async (
};
const response: any = await modelCreateCall(accessToken, new_model);
callback && callback()
console.log(`response for model create call: ${response["data"]}`);
});
}
callback && callback()
form.resetFields();
} catch (error) {
message.error("Failed to create model: " + error, 10);

View file

@ -5,9 +5,9 @@ import { Row, Col } from "antd";
import { Providers } from "../provider_info_helpers";
interface LiteLLMModelNameFieldProps {
selectedProvider: string;
selectedProvider: Providers;
providerModels: string[];
getPlaceholder: (provider: string) => string;
getPlaceholder: (provider: Providers) => string;
}
const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
@ -17,10 +17,21 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
}) => {
const form = Form.useFormInstance();
const handleModelChange = (value: string[]) => {
const handleModelChange = (value: string | string[]) => {
// Ensure value is always treated as an array
const values = Array.isArray(value) ? value : [value];
// If "all-wildcard" is selected, clear the model_name field
if (value.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined });
if (values.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
} else {
// Update model mappings immediately for each selected model
const mappings = values
.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldsValue({ model_mappings: mappings });
}
};
@ -39,9 +50,10 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
{(selectedProvider === Providers.Azure) ||
(selectedProvider === Providers.OpenAI_Compatible) ||
(selectedProvider === Providers.Ollama) ? (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
) : providerModels.length > 0 ? (
<AntSelect
mode="multiple"
allowClear
showSearch
placeholder="Select models"
@ -67,7 +79,7 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
style={{ width: '100%' }}
/>
) : (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
)}
</Form.Item>

View file

@ -104,6 +104,7 @@ import { Team } from "./key_team_helpers/key_list";
import TeamInfoView from "./team/team_info";
import { Providers, provider_map, providerLogoMap, getProviderLogoAndName, getPlaceholder, getProviderModels } from "./provider_info_helpers";
import ModelInfoView from "./model_info_view";
import AddModelTab from "./add_model/add_model_tab";
interface ModelDashboardProps {
accessToken: string | null;
@ -1046,7 +1047,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
.validateFields()
.then((values) => {
handleAddModelSubmit(values, accessToken, form, handleRefreshClick);
// form.resetFields();
})
.catch((error) => {
console.error("Validation failed:", error);
@ -1582,96 +1582,19 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
/>
</TabPanel>
<TabPanel className="h-full">
<Title2 level={2}>Add new model</Title2>
<Card>
<Form
<AddModelTab
form={form}
onFinish={handleOk}
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<>
{/* Provider Selection */}
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="Provider:"
name="custom_llm_provider"
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
labelCol={{ span: 10 }}
labelAlign="left"
>
<AntdSelect
showSearch={true}
value={selectedProvider}
onChange={(value) => {
setSelectedProvider(value);
setProviderModelsFn(value);
form.setFieldsValue({
model: [],
model_name: undefined
});
}}
>
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
<AntdSelect.Option
key={providerEnum}
value={providerEnum}
>
<div className="flex items-center space-x-2">
<img
src={providerLogoMap[providerDisplayName]}
alt={`${providerEnum} logo`}
className="w-5 h-5"
onError={(e) => {
// Create a div with provider initial as fallback
const target = e.target as HTMLImageElement;
const parent = target.parentElement;
if (parent) {
const fallbackDiv = document.createElement('div');
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
fallbackDiv.textContent = providerDisplayName.charAt(0);
parent.replaceChild(fallbackDiv, target);
}
}}
/>
<span>{providerDisplayName}</span>
</div>
</AntdSelect.Option>
))}
</AntdSelect>
</Form.Item>
<LiteLLMModelNameField
handleOk={handleOk}
selectedProvider={selectedProvider}
setSelectedProvider={setSelectedProvider}
providerModels={providerModels}
setProviderModelsFn={setProviderModelsFn}
getPlaceholder={getPlaceholder}
/>
{/* Conditionally Render "Public Model Name" */}
<ConditionalPublicModelName />
<ProviderSpecificFields
selectedProvider={selectedProvider}
uploadProps={uploadProps}
/>
<AdvancedSettings
showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings}
teams={teams}
/>
<div className="flex justify-between items-center mb-4">
<Tooltip title="Get help on our github">
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
Need Help?
</Typography.Link>
</Tooltip>
<Button2 htmlType="submit">Add Model</Button2>
</div>
</>
</Form>
</Card>
</TabPanel>
<TabPanel>
<Card>

View file

@ -92,7 +92,6 @@ export const modelCostMap = async (
throw error;
}
};
export const modelCreateCall = async (
accessToken: string,
formValues: Model
@ -106,7 +105,7 @@ export const modelCreateCall = async (
"Content-Type": "application/json",
},
body: JSON.stringify({
...formValues, // Include formValues in the request body
...formValues,
}),
});
@ -121,9 +120,13 @@ export const modelCreateCall = async (
const data = await response.json();
console.log("API Response:", data);
message.success(
"Model created successfully"
);
// Close any existing messages before showing new ones
message.destroy();
// Sequential success messages
message.success(`Model ${formValues.model_name} created successfully`, 2);
return data;
} catch (error) {
console.error("Failed to create key:", error);