(UI) - allow assigning wildcard models to a team / key (#8041)

* fix message.error

* fix add return_wildcard_routes

* ui edit modelAvailableCall

* fetchAvailableModelsForTeamOrKey

* ui set all models for a team

* ui define common helpers

* edit create key button

* fix viewing model display names

* fix editing team models

* update gitignore

* add jest testing for ui

* Revert "add jest testing for ui"

This reverts commit 98f9a3ebfd.
This commit is contained in:
Ishaan Jaff 2025-01-27 18:06:22 -08:00 committed by GitHub
parent 3a4f5b23b5
commit 7f2742334c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 142 additions and 41 deletions

2
.gitignore vendored
View file

@ -48,7 +48,7 @@ deploy/charts/litellm/charts/*
deploy/charts/*.tgz
litellm/proxy/vertex_key.json
**/.vim/
/node_modules
**/node_modules
kub.yaml
loadtest_kub.yaml
litellm/proxy/_new_secret_config.yaml

View file

@ -1,6 +1,6 @@
# What is this?
## Common checks for /v1/models and `/model/info`
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set
import litellm
from litellm._logging import verbose_proxy_logger
@ -11,6 +11,11 @@ from litellm.utils import get_valid_models
def _check_wildcard_routing(model: str) -> bool:
"""
Returns True if a model is a provider wildcard.
eg:
- anthropic/*
- openai/*
- *
"""
if model == "*":
return True
@ -119,6 +124,7 @@ def get_complete_model_list(
proxy_model_list: List[str],
user_model: Optional[str],
infer_model_from_keys: Optional[bool],
return_wildcard_routes: Optional[bool] = False,
) -> List[str]:
"""Logic for returning complete model list for a given key + team pair"""
@ -128,7 +134,7 @@ def get_complete_model_list(
If list contains wildcard -> return known provider models
"""
unique_models = set()
unique_models: Set[str] = set()
if key_models:
unique_models.update(key_models)
elif team_models:
@ -143,10 +149,26 @@ def get_complete_model_list(
valid_models = get_valid_models()
unique_models.update(valid_models)
all_wildcard_models = _get_wildcard_models(
unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
)
return list(unique_models) + all_wildcard_models
def _get_wildcard_models(
unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
) -> List[str]:
models_to_remove = set()
all_wildcard_models = []
for model in unique_models:
if _check_wildcard_routing(model=model):
if (
return_wildcard_routes is True
): # will add the wildcard route to the list eg: anthropic/*.
all_wildcard_models.append(model)
provider = model.split("/")[0]
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
@ -157,4 +179,4 @@ def get_complete_model_list(
for model in models_to_remove:
unique_models.remove(model)
return list(unique_models) + all_wildcard_models
return all_wildcard_models

View file

@ -3322,6 +3322,7 @@ class ProxyStartupEvent:
) # if project requires model list
async def model_list(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
return_wildcard_routes: Optional[bool] = False,
):
"""
Use `/model/info` - to get detailed model information, example - pricing, mode, etc.
@ -3354,6 +3355,7 @@ async def model_list(
proxy_model_list=proxy_model_list,
user_model=user_model,
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
return_wildcard_routes=return_wildcard_routes,
)
return dict(
data=[

View file

@ -23,6 +23,7 @@ import {
message,
Radio,
} from "antd";
import { unfurlWildcardModelsInList, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
import {
keyCreateCall,
slackBudgetAlertsHealthCheck,
@ -80,8 +81,8 @@ const CreateKey: React.FC<CreateKeyProps> = ({
const [isModalVisible, setIsModalVisible] = useState(false);
const [apiKey, setApiKey] = useState(null);
const [softBudget, setSoftBudget] = useState(null);
const [userModels, setUserModels] = useState([]);
const [modelsToPick, setModelsToPick] = useState([]);
const [userModels, setUserModels] = useState<string[]>([]);
const [modelsToPick, setModelsToPick] = useState<string[]>([]);
const [keyOwner, setKeyOwner] = useState("you");
const [predefinedTags, setPredefinedTags] = useState(getPredefinedTags(data));
const [guardrailsList, setGuardrailsList] = useState<string[]>([]);
@ -213,6 +214,8 @@ const CreateKey: React.FC<CreateKeyProps> = ({
tempModelsToPick = userModels;
}
tempModelsToPick = unfurlWildcardModelsInList(tempModelsToPick, userModels);
setModelsToPick(tempModelsToPick);
}, [team, userModels]);
@ -310,7 +313,7 @@ const CreateKey: React.FC<CreateKeyProps> = ({
</Option>
{modelsToPick.map((model: string) => (
<Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Option>
))}
</Select>

View file

@ -0,0 +1,82 @@
import { modelAvailableCall } from "../networking";
export const fetchAvailableModelsForTeamOrKey = async (
userID: string,
userRole: string,
accessToken: string,
): Promise<string[] | undefined> => {
try {
if (userID === null || userRole === null) {
return;
}
if (accessToken !== null) {
const model_available = await modelAvailableCall(
accessToken,
userID,
userRole,
true
);
let available_model_names = model_available["data"].map(
(element: { id: string }) => element.id
);
// Group and sort models
const providerModels: string[] = [];
const specificModels: string[] = [];
available_model_names.forEach((model: string) => {
if (model.endsWith('/*')) {
providerModels.push(model);
} else {
specificModels.push(model);
}
});
// Combine arrays with provider models first
return [...providerModels, ...specificModels];
}
} catch (error) {
console.error("Error fetching user models:", error);
}
};
export const getModelDisplayName = (model: string) => {
console.log("getModelDisplayName", model);
if (model.endsWith('/*')) {
const provider = model.replace('/*', '');
return `All ${provider} models`;
}
return model;
};
export const unfurlWildcardModelsInList = (teamModels: string[], allModels: string[]): string[] => {
const wildcardDisplayNames: string[] = [];
const expandedModels: string[] = [];
console.log("teamModels", teamModels);
console.log("allModels", allModels);
teamModels.forEach(teamModel => {
if (teamModel.endsWith('/*')) {
// Extract the provider prefix (e.g., 'openai' from 'openai/*')
const provider = teamModel.replace('/*', '');
// Find all models that start with this provider
const matchingModels = allModels.filter(model =>
model.startsWith(provider + '/')
);
expandedModels.push(...matchingModels);
wildcardDisplayNames.push(teamModel);
}
else {
expandedModels.push(teamModel);
}
});
// Combine arrays with wildcard display names first, then remove duplicates
return [...wildcardDisplayNames, ...expandedModels].filter((item, index, array) =>
array.indexOf(item) === index
);
};

View file

@ -1318,7 +1318,8 @@ export const modelExceptionsCall = async (
export const modelAvailableCall = async (
accessToken: String,
userID: String,
userRole: String
userRole: String,
return_wildcard_routes: boolean = false
) => {
/**
* Get all the models user has access to
@ -1326,6 +1327,9 @@ export const modelAvailableCall = async (
console.log("in /models calls, globalLitellmHeaderName", globalLitellmHeaderName)
try {
let url = proxyBaseUrl ? `${proxyBaseUrl}/models` : `/models`;
if (return_wildcard_routes === true) {
url += `?return_wildcard_routes=True`;
}
//message.info("Requesting model data");
const response = await fetch(url, {

View file

@ -43,7 +43,6 @@ const AvailableTeamsPanel: React.FC<AvailableTeamsProps> = ({
setAvailableTeams(response);
} catch (error) {
console.error('Error fetching available teams:', error);
message.error('Failed to load available teams');
}
};

View file

@ -21,6 +21,7 @@ import {
message,
Tooltip
} from "antd";
import { fetchAvailableModelsForTeamOrKey, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
import { Select, SelectItem } from "@tremor/react";
import { InfoCircleOutlined } from '@ant-design/icons';
import { getGuardrailsList } from "./networking";
@ -134,7 +135,7 @@ const Team: React.FC<TeamProps> = ({
const [isTeamModalVisible, setIsTeamModalVisible] = useState(false);
const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false);
const [isEditMemberModalVisible, setIsEditMemberModalVisible] = useState(false);
const [userModels, setUserModels] = useState([]);
const [userModels, setUserModels] = useState<string[]>([]);
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
const [teamToDelete, setTeamToDelete] = useState<string | null>(null);
const [selectedEditMember, setSelectedEditMember] = useState<null | TeamMember>(null);
@ -236,7 +237,7 @@ const Team: React.FC<TeamProps> = ({
{userModels &&
userModels.map((model) => (
<Select2.Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Select2.Option>
))}
</Select2>
@ -405,21 +406,12 @@ const Team: React.FC<TeamProps> = ({
useEffect(() => {
const fetchUserModels = async () => {
try {
if (userID === null || userRole === null) {
if (userID === null || userRole === null || accessToken === null) {
return;
}
if (accessToken !== null) {
const model_available = await modelAvailableCall(
accessToken,
userID,
userRole
);
let available_model_names = model_available["data"].map(
(element: { id: string }) => element.id
);
console.log("available_model_names:", available_model_names);
setUserModels(available_model_names);
const models = await fetchAvailableModelsForTeamOrKey(userID, userRole, accessToken);
if (models) {
setUserModels(models);
}
} catch (error) {
console.error("Error fetching user models:", error);
@ -715,8 +707,8 @@ const Team: React.FC<TeamProps> = ({
>
<Text>
{model.length > 30
? `${model.slice(0, 30)}...`
: model}
? `${getModelDisplayName(model).slice(0, 30)}...`
: getModelDisplayName(model)}
</Text>
</Badge>
)
@ -875,7 +867,7 @@ const Team: React.FC<TeamProps> = ({
</Select2.Option>
{userModels.map((model) => (
<Select2.Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Select2.Option>
))}
</Select2>

View file

@ -27,6 +27,7 @@ import {
Textarea,
} from "@tremor/react";
import { InfoCircleOutlined } from '@ant-design/icons';
import { fetchAvailableModelsForTeamOrKey, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
import {
Button as Button2,
@ -121,7 +122,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
const [editModalVisible, setEditModalVisible] = useState(false);
const [infoDialogVisible, setInfoDialogVisible] = useState(false);
const [selectedToken, setSelectedToken] = useState<ItemData | null>(null);
const [userModels, setUserModels] = useState([]);
const [userModels, setUserModels] = useState<string[]>([]);
const initialKnownTeamIDs: Set<string> = new Set();
const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false);
const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false);
@ -238,17 +239,13 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
useEffect(() => {
const fetchUserModels = async () => {
try {
if (userID === null) {
if (userID === null || userRole === null || accessToken === null) {
return;
}
if (accessToken !== null && userRole !== null) {
const model_available = await modelAvailableCall(accessToken, userID, userRole);
let available_model_names = model_available["data"].map(
(element: { id: string }) => element.id
);
console.log("available_model_names:", available_model_names);
setUserModels(available_model_names);
const models = await fetchAvailableModelsForTeamOrKey(userID, userRole, accessToken);
if (models) {
setUserModels(models);
}
} catch (error) {
console.error("Error fetching user models:", error);
@ -424,20 +421,20 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
keyTeam.models.includes("all-proxy-models") ? (
userModels.filter(model => model !== "all-proxy-models").map((model: string) => (
<Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Option>
))
) : (
keyTeam.models.map((model: string) => (
<Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Option>
))
)
) : (
userModels.map((model: string) => (
<Option key={model} value={model}>
{model}
{getModelDisplayName(model)}
</Option>
))
)}
@ -1095,7 +1092,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</Badge>
) : (
<Badge key={index} size={"xs"} className="mb-1" color="blue">
<Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text>
<Text>{model.length > 30 ? `${getModelDisplayName(model).slice(0, 30)}...` : getModelDisplayName(model)}</Text>
</Badge>
)
))
@ -1118,7 +1115,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</Badge>
) : (
<Badge key={index} size={"xs"} className="mb-1" color="blue">
<Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text>
<Text>{model.length > 30 ? `${getModelDisplayName(model).slice(0, 30)}...` : getModelDisplayName(model)}</Text>
</Badge>
)
))