mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
UI - fix edit azure public model name + fix editing model name post create
* test(test_router.py): add unit test confirming fallbacks with tag based routing works as expected * test: update testing * test: update test to not use gemini-pro google removed it * fix(conditional_public_model_name.tsx): edit azure public model name Fixes https://github.com/BerriAI/litellm/issues/10093 * fix(model_info_view.tsx): migrate to patch model updates Enables changing model name easily
This commit is contained in:
parent
acd2c1783c
commit
be4152c8d5
5 changed files with 136 additions and 21 deletions
|
@ -78,3 +78,43 @@ def test_router_with_model_info_and_model_group():
|
|||
model_group="gpt-3.5-turbo",
|
||||
user_facing_model_group_name="gpt-3.5-turbo",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_with_tags_and_fallbacks():
|
||||
"""
|
||||
If fallback model missing tag, raise error
|
||||
"""
|
||||
from litellm import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"mock_response": "Hello, world!",
|
||||
"tags": ["test"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic-claude-3-5-sonnet",
|
||||
"litellm_params": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"mock_response": "Hello, world 2!",
|
||||
},
|
||||
},
|
||||
],
|
||||
fallbacks=[
|
||||
{"gpt-3.5-turbo": ["anthropic-claude-3-5-sonnet"]},
|
||||
],
|
||||
enable_tag_filtering=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_testing_fallbacks=True,
|
||||
metadata={"tags": ["test"]},
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ const ConditionalPublicModelName: React.FC = () => {
|
|||
const customModelName = Form.useWatch('custom_model_name', form);
|
||||
const showPublicModelName = !selectedModels.includes('all-wildcard');
|
||||
|
||||
|
||||
// Force table to re-render when custom model name changes
|
||||
useEffect(() => {
|
||||
if (customModelName && selectedModels.includes('custom')) {
|
||||
|
@ -35,20 +36,33 @@ const ConditionalPublicModelName: React.FC = () => {
|
|||
// Initial setup of model mappings when models are selected
|
||||
useEffect(() => {
|
||||
if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
|
||||
const mappings = selectedModels.map((model: string) => {
|
||||
if (model === 'custom' && customModelName) {
|
||||
// Check if we already have mappings that match the selected models
|
||||
const currentMappings = form.getFieldValue('model_mappings') || [];
|
||||
|
||||
// Only update if the mappings don't exist or don't match the selected models
|
||||
const shouldUpdateMappings = currentMappings.length !== selectedModels.length ||
|
||||
!selectedModels.every(model =>
|
||||
currentMappings.some((mapping: { public_name: string; litellm_model: string }) =>
|
||||
mapping.public_name === model ||
|
||||
(model === 'custom' && mapping.public_name === customModelName)));
|
||||
|
||||
if (shouldUpdateMappings) {
|
||||
const mappings = selectedModels.map((model: string) => {
|
||||
if (model === 'custom' && customModelName) {
|
||||
return {
|
||||
public_name: customModelName,
|
||||
litellm_model: customModelName
|
||||
};
|
||||
}
|
||||
return {
|
||||
public_name: customModelName,
|
||||
litellm_model: customModelName
|
||||
public_name: model,
|
||||
litellm_model: model
|
||||
};
|
||||
}
|
||||
return {
|
||||
public_name: model,
|
||||
litellm_model: model
|
||||
};
|
||||
});
|
||||
form.setFieldValue('model_mappings', mappings);
|
||||
setTableKey(prev => prev + 1); // Force table re-render
|
||||
});
|
||||
|
||||
form.setFieldValue('model_mappings', mappings);
|
||||
setTableKey(prev => prev + 1); // Force table re-render
|
||||
}
|
||||
}
|
||||
}, [selectedModels, customModelName, form]);
|
||||
|
||||
|
|
|
@ -23,22 +23,34 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
|
|||
|
||||
// If "all-wildcard" is selected, clear the model_name field
|
||||
if (values.includes("all-wildcard")) {
|
||||
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
|
||||
form.setFieldsValue({ model: undefined, model_mappings: [] });
|
||||
} else {
|
||||
// Update model mappings immediately for each selected model
|
||||
const mappings = values
|
||||
.map(model => ({
|
||||
// Get current model value to check if we need to update
|
||||
const currentModel = form.getFieldValue('model');
|
||||
|
||||
// Only update if the value has actually changed
|
||||
if (JSON.stringify(currentModel) !== JSON.stringify(values)) {
|
||||
|
||||
// Create mappings first
|
||||
const mappings = values.map(model => ({
|
||||
public_name: model,
|
||||
litellm_model: model
|
||||
}));
|
||||
form.setFieldsValue({ model_mappings: mappings });
|
||||
|
||||
// Update both fields in one call to reduce re-renders
|
||||
form.setFieldsValue({
|
||||
model: values,
|
||||
model_mappings: mappings
|
||||
});
|
||||
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Handle custom model name changes
|
||||
const handleCustomModelNameChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const customName = e.target.value;
|
||||
|
||||
|
||||
// Immediately update the model mappings
|
||||
const currentMappings = form.getFieldValue('model_mappings') || [];
|
||||
const updatedMappings = currentMappings.map((mapping: any) => {
|
||||
|
@ -69,7 +81,11 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
|
|||
{(selectedProvider === Providers.Azure) ||
|
||||
(selectedProvider === Providers.OpenAI_Compatible) ||
|
||||
(selectedProvider === Providers.Ollama) ? (
|
||||
<TextInput placeholder={getPlaceholder(selectedProvider)} />
|
||||
<>
|
||||
<TextInput
|
||||
placeholder={getPlaceholder(selectedProvider)}
|
||||
/>
|
||||
</>
|
||||
) : providerModels.length > 0 ? (
|
||||
<AntSelect
|
||||
mode="multiple"
|
||||
|
|
|
@ -15,7 +15,7 @@ import {
|
|||
} from "@tremor/react";
|
||||
import NumericalInput from "./shared/numerical_input";
|
||||
import { ArrowLeftIcon, TrashIcon, KeyIcon } from "@heroicons/react/outline";
|
||||
import { modelDeleteCall, modelUpdateCall, CredentialItem, credentialGetCall, credentialCreateCall, modelInfoCall, modelInfoV1Call } from "./networking";
|
||||
import { modelDeleteCall, modelUpdateCall, CredentialItem, credentialGetCall, credentialCreateCall, modelInfoCall, modelInfoV1Call, modelPatchUpdateCall } from "./networking";
|
||||
import { Button, Form, Input, InputNumber, message, Select, Modal } from "antd";
|
||||
import EditModelModal from "./edit_model/edit_model_modal";
|
||||
import { handleEditModelSubmit } from "./edit_model/edit_model_modal";
|
||||
|
@ -118,6 +118,8 @@ export default function ModelInfoView({
|
|||
try {
|
||||
if (!accessToken) return;
|
||||
setIsSaving(true);
|
||||
|
||||
console.log("values.model_name, ", values.model_name);
|
||||
|
||||
let updatedLitellmParams = {
|
||||
...localModelData.litellm_params,
|
||||
|
@ -149,7 +151,7 @@ export default function ModelInfoView({
|
|||
}
|
||||
};
|
||||
|
||||
await modelUpdateCall(accessToken, updateData);
|
||||
await modelPatchUpdateCall(accessToken, updateData, modelId);
|
||||
|
||||
const updatedModelData = {
|
||||
...localModelData,
|
||||
|
|
|
@ -3152,6 +3152,49 @@ export const teamUpdateCall = async (
|
|||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Patch update a model
|
||||
*
|
||||
* @param accessToken
|
||||
* @param formValues
|
||||
* @returns
|
||||
*/
|
||||
export const modelPatchUpdateCall = async (
|
||||
accessToken: string,
|
||||
formValues: Record<string, any>, // Assuming formValues is an object
|
||||
modelId: string
|
||||
) => {
|
||||
try {
|
||||
console.log("Form Values in modelUpateCall:", formValues); // Log the form values before making the API call
|
||||
|
||||
const url = proxyBaseUrl ? `${proxyBaseUrl}/model/${modelId}/update` : `/model/${modelId}/update`;
|
||||
const response = await fetch(url, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
...formValues, // Include formValues in the request body
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
handleError(errorData);
|
||||
console.error("Error update from the server:", errorData);
|
||||
throw new Error("Network response was not ok");
|
||||
}
|
||||
const data = await response.json();
|
||||
console.log("Update model Response:", data);
|
||||
return data;
|
||||
// Handle success - you might want to update some state or UI based on the created key
|
||||
} catch (error) {
|
||||
console.error("Failed to update model:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const modelUpdateCall = async (
|
||||
accessToken: string,
|
||||
formValues: Record<string, any> // Assuming formValues is an object
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue