(UI) fix adding Vertex Models (#8129)

* fix handleSubmit

* update handleAddModelSubmit

* add jest testing for ui

* add step for running ui unit tests

* add validate json step to add model

* ui jest testing fixes

* update package lock

* ci/cd run again

* fix antd import

* run jest tests first

* fix antd install

* fix ui unit tests

* fix unit test ui
This commit is contained in:
Ishaan Jaff 2025-01-30 21:11:08 -08:00 committed by GitHub
parent 8a235e7d38
commit 4005a51db2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 6814 additions and 120 deletions

View file

@ -11,6 +11,7 @@
"@cloudflare/workers-types"
],
"jsx": "react-jsx",
"jsxImportSource": "hono/jsx"
"jsxImportSource": "hono/jsx",
"skipLibCheck": true
},
}

View file

@ -11,7 +11,7 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
) # Adds the parent directory to the system path
import os

View file

@ -0,0 +1 @@
module.exports = 'test-file-stub';

View file

@ -0,0 +1,83 @@
import { handleAddModelSubmit } from '../../../ui/litellm-dashboard/src/components/add_model/handle_add_model_submit';
import { modelCreateCall } from '../../../ui/litellm-dashboard/src/components/networking';
// Mock the dependencies
const mockModelCreateCall = jest.fn().mockResolvedValue({ data: 'success' });
jest.mock('../../../ui/litellm-dashboard/src/components/networking', () => ({
modelCreateCall: async (accessToken: string, formValues: any) => mockModelCreateCall(formValues)
}));
// Also need to mock provider_map
jest.mock('../../../ui/litellm-dashboard/src/components/provider_info_helpers', () => ({
provider_map: {
'openai': 'openai'
}
}));
jest.mock('antd', () => ({
message: {
error: jest.fn()
}
}));
describe('handleAddModelSubmit', () => {
const mockForm = {
resetFields: jest.fn()
};
const mockAccessToken = 'test-token';
beforeEach(() => {
jest.clearAllMocks();
mockModelCreateCall.mockClear();
});
it('should not modify model name when all-wildcard is not selected', async () => {
const formValues = {
model: 'gpt-4',
custom_llm_provider: 'openai',
model_name: 'my-gpt4-deployment'
};
await handleAddModelSubmit(formValues, mockAccessToken, mockForm);
console.log('Expected call:', {
model_name: 'my-gpt4-deployment',
litellm_params: {
model: 'gpt-4',
custom_llm_provider: 'openai'
},
model_info: {}
});
console.log('Actual calls:', mockModelCreateCall.mock.calls);
expect(mockModelCreateCall).toHaveBeenCalledWith({
model_name: 'my-gpt4-deployment',
litellm_params: {
model: 'gpt-4',
custom_llm_provider: 'openai'
},
model_info: {}
});
expect(mockForm.resetFields).toHaveBeenCalled();
});
it('should handle all-wildcard model correctly', async () => {
const formValues = {
model: 'all-wildcard',
custom_llm_provider: 'openai',
model_name: 'my-deployment'
};
await handleAddModelSubmit(formValues, mockAccessToken, mockForm);
expect(mockModelCreateCall).toHaveBeenCalledWith({
model_name: 'openai/*',
litellm_params: {
model: 'openai/*',
custom_llm_provider: 'openai'
},
model_info: {}
});
expect(mockForm.resetFields).toHaveBeenCalled();
});
});

View file

@ -0,0 +1,18 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'jsdom',
moduleNameMapper: {
'\\.(css|less|scss|sass)$': 'identity-obj-proxy',
'\\.(jpg|jpeg|png|gif|webp|svg)$': '<rootDir>/__mocks__/fileMock.js'
},
setupFilesAfterEnv: ['<rootDir>/jest.setup.js'],
testMatch: [
'<rootDir>/**/*.test.tsx',
'<rootDir>/**/*_test.tsx' // Added this to match your file naming
],
moduleDirectories: ['node_modules'],
testPathIgnorePatterns: ['/node_modules/'],
transform: {
'^.+\\.(ts|tsx)$': 'ts-jest'
}
}

View file

@ -0,0 +1 @@
// Add any global setup here

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,26 @@
{
"name": "ui-unit-tests",
"version": "1.0.0",
"scripts": {
"test": "jest",
"test:watch": "jest --watch"
},
"devDependencies": {
"@testing-library/react": "^14.0.0",
"@testing-library/jest-dom": "^6.0.0",
"@types/jest": "^29.5.0",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.0",
"identity-obj-proxy": "^3.0.0",
"jest": "^29.5.0",
"jest-environment-jsdom": "^29.5.0",
"ts-jest": "^29.1.0",
"typescript": "^5.0.0"
},
"dependencies": {
"antd": "^5.12.5",
"@ant-design/icons": "^5.0.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
}
}

View file

@ -0,0 +1,30 @@
{
"compilerOptions": {
"target": "es5",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noFallthroughCasesInSwitch": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
"baseUrl": ".",
"paths": {
"*": ["*", "node_modules/*"]
}
},
"include": [
"**/*.ts",
"**/*.tsx"
],
"exclude": [
"node_modules"
]
}

View file

@ -16,6 +16,19 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
}) => {
const [form] = Form.useForm();
// Add validation function
const validateJSON = (_: any, value: string) => {
if (!value) {
return Promise.resolve();
}
try {
JSON.parse(value);
return Promise.resolve();
} catch (error) {
return Promise.reject('Please enter valid JSON');
}
};
const handlePassThroughChange = (checked: boolean) => {
const currentParams = form.getFieldValue('litellm_extra_params');
try {
@ -75,6 +88,7 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
name="litellm_extra_params"
tooltip="Optional litellm params used for making a litellm.completion() call."
className="mb-4 mt-4"
rules={[{ validator: validateJSON }]}
>
<TextArea
rows={4}
@ -104,6 +118,7 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
name="model_info_params"
tooltip="Optional model info params. Returned when calling `/model/info` endpoint."
className="mb-0"
rules={[{ validator: validateJSON }]}
>
<TextArea
rows={4}

View file

@ -0,0 +1,122 @@
import { message } from "antd";
import { provider_map } from "../provider_info_helpers";
import { modelCreateCall, Model } from "../networking";
export const handleAddModelSubmit = async (
formValues: Record<string, any>,
accessToken: string,
form: any
) => {
try {
console.log("handling submit for formValues:", formValues);
// If model_name is not provided, use provider.toLowerCase() + "/*"
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
const wildcardModel = formValues["custom_llm_provider"].toLowerCase() + "/*";
formValues["model_name"] = wildcardModel;
formValues["model"] = wildcardModel;
}
/**
* For multiple litellm model names - create a separate deployment for each
* - get the list
* - iterate through it
* - create a new deployment for each
*
* For single model name -> make it a 1 item list
*/
// get the list of deployments
let deployments: Array<string> = Array.isArray(formValues["model"])
? formValues["model"]
: [formValues["model"]];
console.log(`received deployments: ${deployments}`);
console.log(`received type of deployments: ${typeof deployments}`);
deployments.forEach(async (litellm_model) => {
console.log(`litellm_model: ${litellm_model}`);
const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {};
// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model;
let modelName: string = "";
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === "") {
continue;
}
if (key == "model_name") {
modelName = modelName + value;
} else if (key == "custom_llm_provider") {
console.log("custom_llm_provider:", value);
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
litellmParamsObj["custom_llm_provider"] = mappingResult;
console.log("custom_llm_provider mappingResult:", mappingResult);
} else if (key == "model") {
continue;
}
// Check if key is "base_model"
else if (key === "base_model") {
// Add key-value pair to model_info dictionary
modelInfoObj[key] = value;
}
else if (key === "custom_model_name") {
litellmParamsObj["model"] = value;
} else if (key == "litellm_extra_params") {
console.log("litellm_extra_params:", value);
let litellmExtraParams = {};
if (value && value != undefined) {
try {
litellmExtraParams = JSON.parse(value);
} catch (error) {
message.error(
"Failed to parse LiteLLM Extra Params: " + error,
10
);
throw new Error("Failed to parse litellm_extra_params: " + error);
}
for (const [key, value] of Object.entries(litellmExtraParams)) {
litellmParamsObj[key] = value;
}
}
} else if (key == "model_info_params") {
console.log("model_info_params:", value);
let modelInfoParams = {};
if (value && value != undefined) {
try {
modelInfoParams = JSON.parse(value);
} catch (error) {
message.error(
"Failed to parse LiteLLM Extra Params: " + error,
10
);
throw new Error("Failed to parse litellm_extra_params: " + error);
}
for (const [key, value] of Object.entries(modelInfoParams)) {
modelInfoObj[key] = value;
}
}
}
// Check if key is any of the specified API related keys
else {
// Add key-value pair to litellm_params dictionary
litellmParamsObj[key] = value;
}
}
const new_model: Model = {
model_name: modelName,
litellm_params: litellmParamsObj,
model_info: modelInfoObj,
};
const response: any = await modelCreateCall(accessToken, new_model);
console.log(`response for model create call: ${response["data"]}`);
});
form.resetFields();
} catch (error) {
message.error("Failed to create model: " + error, 10);
}
};

View file

@ -19,6 +19,7 @@ import {
import ConditionalPublicModelName from "./add_model/conditional_public_model_name";
import LiteLLMModelNameField from "./add_model/litellm_model_name";
import AdvancedSettings from "./add_model/advanced_settings";
import { handleAddModelSubmit } from "./add_model/handle_add_model_submit";
import {
TabPanel,
TabPanels,
@ -152,123 +153,6 @@ const retry_policy_map: Record<string, string> = {
"InternalServerError (500)": "InternalServerErrorRetries",
};
export const handleSubmit = async (
formValues: Record<string, any>,
accessToken: string,
form: any
) => {
try {
// If model_name is not provided, use provider.toLowerCase() + "/*"
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
formValues["model_name"] = formValues["custom_llm_provider"].toLowerCase() + "/*";
}
formValues["model"] = [formValues["custom_llm_provider"].toLowerCase() + "/*"];
/**
* For multiple litellm model names - create a separate deployment for each
* - get the list
* - iterate through it
* - create a new deployment for each
*
* For single model name -> make it a 1 item list
*/
// get the list of deployments
let deployments: Array<string> = Array.isArray(formValues["model"])
? formValues["model"]
: [formValues["model"]];
console.log(`received deployments: ${deployments}`);
console.log(`received type of deployments: ${typeof deployments}`);
deployments.forEach(async (litellm_model) => {
console.log(`litellm_model: ${litellm_model}`);
const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {};
// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model;
let modelName: string = "";
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === "") {
continue;
}
if (key == "model_name") {
modelName = modelName + value;
} else if (key == "custom_llm_provider") {
console.log("custom_llm_provider:", value);
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
litellmParamsObj["custom_llm_provider"] = mappingResult;
console.log("custom_llm_provider mappingResult:", mappingResult);
} else if (key == "model") {
continue;
}
// Check if key is "base_model"
else if (key === "base_model") {
// Add key-value pair to model_info dictionary
modelInfoObj[key] = value;
}
else if (key === "custom_model_name") {
litellmParamsObj["model"] = value;
} else if (key == "litellm_extra_params") {
console.log("litellm_extra_params:", value);
let litellmExtraParams = {};
if (value && value != undefined) {
try {
litellmExtraParams = JSON.parse(value);
} catch (error) {
message.error(
"Failed to parse LiteLLM Extra Params: " + error,
10
);
throw new Error("Failed to parse litellm_extra_params: " + error);
}
for (const [key, value] of Object.entries(litellmExtraParams)) {
litellmParamsObj[key] = value;
}
}
} else if (key == "model_info_params") {
console.log("model_info_params:", value);
let modelInfoParams = {};
if (value && value != undefined) {
try {
modelInfoParams = JSON.parse(value);
} catch (error) {
message.error(
"Failed to parse LiteLLM Extra Params: " + error,
10
);
throw new Error("Failed to parse litellm_extra_params: " + error);
}
for (const [key, value] of Object.entries(modelInfoParams)) {
modelInfoObj[key] = value;
}
}
}
// Check if key is any of the specified API related keys
else {
// Add key-value pair to litellm_params dictionary
litellmParamsObj[key] = value;
}
}
const new_model: Model = {
model_name: modelName,
litellm_params: litellmParamsObj,
model_info: modelInfoObj,
};
const response: any = await modelCreateCall(accessToken, new_model);
console.log(`response for model create call: ${response["data"]}`);
});
form.resetFields();
} catch (error) {
message.error("Failed to create model: " + error, 10);
}
};
const ModelDashboard: React.FC<ModelDashboardProps> = ({
accessToken,
token,
@ -1310,7 +1194,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
form
.validateFields()
.then((values) => {
handleSubmit(values, accessToken, form);
handleAddModelSubmit(values, accessToken, form);
// form.resetFields();
})
.catch((error) => {