mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(UI) - allow assigning wildcard models to a team / key (#8041)
* fix message.error
* fix add return_wildcard_routes
* ui edit modelAvailableCall
* fetchAvailableModelsForTeamOrKey
* ui set all models for a team
* ui define common helpers
* edit create key button
* fix viewing model display names
* fix editing team models
* update gitignore
* add jest testing for ui
* Revert "add jest testing for ui"
This reverts commit 98f9a3ebfd
.
This commit is contained in:
parent
3a4f5b23b5
commit
7f2742334c
9 changed files with 142 additions and 41 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -48,7 +48,7 @@ deploy/charts/litellm/charts/*
|
|||
deploy/charts/*.tgz
|
||||
litellm/proxy/vertex_key.json
|
||||
**/.vim/
|
||||
/node_modules
|
||||
**/node_modules
|
||||
kub.yaml
|
||||
loadtest_kub.yaml
|
||||
litellm/proxy/_new_secret_config.yaml
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# What is this?
|
||||
## Common checks for /v1/models and `/model/info`
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -11,6 +11,11 @@ from litellm.utils import get_valid_models
|
|||
def _check_wildcard_routing(model: str) -> bool:
|
||||
"""
|
||||
Returns True if a model is a provider wildcard.
|
||||
|
||||
eg:
|
||||
- anthropic/*
|
||||
- openai/*
|
||||
- *
|
||||
"""
|
||||
if model == "*":
|
||||
return True
|
||||
|
@ -119,6 +124,7 @@ def get_complete_model_list(
|
|||
proxy_model_list: List[str],
|
||||
user_model: Optional[str],
|
||||
infer_model_from_keys: Optional[bool],
|
||||
return_wildcard_routes: Optional[bool] = False,
|
||||
) -> List[str]:
|
||||
"""Logic for returning complete model list for a given key + team pair"""
|
||||
|
||||
|
@ -128,7 +134,7 @@ def get_complete_model_list(
|
|||
|
||||
If list contains wildcard -> return known provider models
|
||||
"""
|
||||
unique_models = set()
|
||||
unique_models: Set[str] = set()
|
||||
if key_models:
|
||||
unique_models.update(key_models)
|
||||
elif team_models:
|
||||
|
@ -143,10 +149,26 @@ def get_complete_model_list(
|
|||
valid_models = get_valid_models()
|
||||
unique_models.update(valid_models)
|
||||
|
||||
all_wildcard_models = _get_wildcard_models(
|
||||
unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
|
||||
)
|
||||
|
||||
return list(unique_models) + all_wildcard_models
|
||||
|
||||
|
||||
def _get_wildcard_models(
|
||||
unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
|
||||
) -> List[str]:
|
||||
models_to_remove = set()
|
||||
all_wildcard_models = []
|
||||
for model in unique_models:
|
||||
if _check_wildcard_routing(model=model):
|
||||
|
||||
if (
|
||||
return_wildcard_routes is True
|
||||
): # will add the wildcard route to the list eg: anthropic/*.
|
||||
all_wildcard_models.append(model)
|
||||
|
||||
provider = model.split("/")[0]
|
||||
# get all known provider models
|
||||
wildcard_models = get_provider_models(provider=provider)
|
||||
|
@ -157,4 +179,4 @@ def get_complete_model_list(
|
|||
for model in models_to_remove:
|
||||
unique_models.remove(model)
|
||||
|
||||
return list(unique_models) + all_wildcard_models
|
||||
return all_wildcard_models
|
||||
|
|
|
@ -3322,6 +3322,7 @@ class ProxyStartupEvent:
|
|||
) # if project requires model list
|
||||
async def model_list(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
return_wildcard_routes: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Use `/model/info` - to get detailed model information, example - pricing, mode, etc.
|
||||
|
@ -3354,6 +3355,7 @@ async def model_list(
|
|||
proxy_model_list=proxy_model_list,
|
||||
user_model=user_model,
|
||||
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||
return_wildcard_routes=return_wildcard_routes,
|
||||
)
|
||||
return dict(
|
||||
data=[
|
||||
|
|
|
@ -23,6 +23,7 @@ import {
|
|||
message,
|
||||
Radio,
|
||||
} from "antd";
|
||||
import { unfurlWildcardModelsInList, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
|
||||
import {
|
||||
keyCreateCall,
|
||||
slackBudgetAlertsHealthCheck,
|
||||
|
@ -80,8 +81,8 @@ const CreateKey: React.FC<CreateKeyProps> = ({
|
|||
const [isModalVisible, setIsModalVisible] = useState(false);
|
||||
const [apiKey, setApiKey] = useState(null);
|
||||
const [softBudget, setSoftBudget] = useState(null);
|
||||
const [userModels, setUserModels] = useState([]);
|
||||
const [modelsToPick, setModelsToPick] = useState([]);
|
||||
const [userModels, setUserModels] = useState<string[]>([]);
|
||||
const [modelsToPick, setModelsToPick] = useState<string[]>([]);
|
||||
const [keyOwner, setKeyOwner] = useState("you");
|
||||
const [predefinedTags, setPredefinedTags] = useState(getPredefinedTags(data));
|
||||
const [guardrailsList, setGuardrailsList] = useState<string[]>([]);
|
||||
|
@ -213,6 +214,8 @@ const CreateKey: React.FC<CreateKeyProps> = ({
|
|||
tempModelsToPick = userModels;
|
||||
}
|
||||
|
||||
tempModelsToPick = unfurlWildcardModelsInList(tempModelsToPick, userModels);
|
||||
|
||||
setModelsToPick(tempModelsToPick);
|
||||
}, [team, userModels]);
|
||||
|
||||
|
@ -310,7 +313,7 @@ const CreateKey: React.FC<CreateKeyProps> = ({
|
|||
</Option>
|
||||
{modelsToPick.map((model: string) => (
|
||||
<Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Option>
|
||||
))}
|
||||
</Select>
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
import { modelAvailableCall } from "../networking";
|
||||
|
||||
|
||||
export const fetchAvailableModelsForTeamOrKey = async (
|
||||
userID: string,
|
||||
userRole: string,
|
||||
accessToken: string,
|
||||
): Promise<string[] | undefined> => {
|
||||
try {
|
||||
if (userID === null || userRole === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (accessToken !== null) {
|
||||
const model_available = await modelAvailableCall(
|
||||
accessToken,
|
||||
userID,
|
||||
userRole,
|
||||
true
|
||||
);
|
||||
|
||||
let available_model_names = model_available["data"].map(
|
||||
(element: { id: string }) => element.id
|
||||
);
|
||||
|
||||
// Group and sort models
|
||||
const providerModels: string[] = [];
|
||||
const specificModels: string[] = [];
|
||||
|
||||
available_model_names.forEach((model: string) => {
|
||||
if (model.endsWith('/*')) {
|
||||
providerModels.push(model);
|
||||
} else {
|
||||
specificModels.push(model);
|
||||
}
|
||||
});
|
||||
|
||||
// Combine arrays with provider models first
|
||||
return [...providerModels, ...specificModels];
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching user models:", error);
|
||||
}
|
||||
};
|
||||
|
||||
export const getModelDisplayName = (model: string) => {
|
||||
console.log("getModelDisplayName", model);
|
||||
if (model.endsWith('/*')) {
|
||||
const provider = model.replace('/*', '');
|
||||
return `All ${provider} models`;
|
||||
}
|
||||
return model;
|
||||
};
|
||||
|
||||
export const unfurlWildcardModelsInList = (teamModels: string[], allModels: string[]): string[] => {
|
||||
const wildcardDisplayNames: string[] = [];
|
||||
const expandedModels: string[] = [];
|
||||
console.log("teamModels", teamModels);
|
||||
console.log("allModels", allModels);
|
||||
|
||||
teamModels.forEach(teamModel => {
|
||||
if (teamModel.endsWith('/*')) {
|
||||
// Extract the provider prefix (e.g., 'openai' from 'openai/*')
|
||||
const provider = teamModel.replace('/*', '');
|
||||
|
||||
// Find all models that start with this provider
|
||||
const matchingModels = allModels.filter(model =>
|
||||
model.startsWith(provider + '/')
|
||||
);
|
||||
expandedModels.push(...matchingModels);
|
||||
wildcardDisplayNames.push(teamModel);
|
||||
}
|
||||
else {
|
||||
expandedModels.push(teamModel);
|
||||
}
|
||||
});
|
||||
|
||||
// Combine arrays with wildcard display names first, then remove duplicates
|
||||
return [...wildcardDisplayNames, ...expandedModels].filter((item, index, array) =>
|
||||
array.indexOf(item) === index
|
||||
);
|
||||
};
|
|
@ -1318,7 +1318,8 @@ export const modelExceptionsCall = async (
|
|||
export const modelAvailableCall = async (
|
||||
accessToken: String,
|
||||
userID: String,
|
||||
userRole: String
|
||||
userRole: String,
|
||||
return_wildcard_routes: boolean = false
|
||||
) => {
|
||||
/**
|
||||
* Get all the models user has access to
|
||||
|
@ -1326,6 +1327,9 @@ export const modelAvailableCall = async (
|
|||
console.log("in /models calls, globalLitellmHeaderName", globalLitellmHeaderName)
|
||||
try {
|
||||
let url = proxyBaseUrl ? `${proxyBaseUrl}/models` : `/models`;
|
||||
if (return_wildcard_routes === true) {
|
||||
url += `?return_wildcard_routes=True`;
|
||||
}
|
||||
|
||||
//message.info("Requesting model data");
|
||||
const response = await fetch(url, {
|
||||
|
|
|
@ -43,7 +43,6 @@ const AvailableTeamsPanel: React.FC<AvailableTeamsProps> = ({
|
|||
setAvailableTeams(response);
|
||||
} catch (error) {
|
||||
console.error('Error fetching available teams:', error);
|
||||
message.error('Failed to load available teams');
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import {
|
|||
message,
|
||||
Tooltip
|
||||
} from "antd";
|
||||
import { fetchAvailableModelsForTeamOrKey, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
|
||||
import { Select, SelectItem } from "@tremor/react";
|
||||
import { InfoCircleOutlined } from '@ant-design/icons';
|
||||
import { getGuardrailsList } from "./networking";
|
||||
|
@ -134,7 +135,7 @@ const Team: React.FC<TeamProps> = ({
|
|||
const [isTeamModalVisible, setIsTeamModalVisible] = useState(false);
|
||||
const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false);
|
||||
const [isEditMemberModalVisible, setIsEditMemberModalVisible] = useState(false);
|
||||
const [userModels, setUserModels] = useState([]);
|
||||
const [userModels, setUserModels] = useState<string[]>([]);
|
||||
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
|
||||
const [teamToDelete, setTeamToDelete] = useState<string | null>(null);
|
||||
const [selectedEditMember, setSelectedEditMember] = useState<null | TeamMember>(null);
|
||||
|
@ -236,7 +237,7 @@ const Team: React.FC<TeamProps> = ({
|
|||
{userModels &&
|
||||
userModels.map((model) => (
|
||||
<Select2.Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
|
@ -405,21 +406,12 @@ const Team: React.FC<TeamProps> = ({
|
|||
useEffect(() => {
|
||||
const fetchUserModels = async () => {
|
||||
try {
|
||||
if (userID === null || userRole === null) {
|
||||
if (userID === null || userRole === null || accessToken === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (accessToken !== null) {
|
||||
const model_available = await modelAvailableCall(
|
||||
accessToken,
|
||||
userID,
|
||||
userRole
|
||||
);
|
||||
let available_model_names = model_available["data"].map(
|
||||
(element: { id: string }) => element.id
|
||||
);
|
||||
console.log("available_model_names:", available_model_names);
|
||||
setUserModels(available_model_names);
|
||||
const models = await fetchAvailableModelsForTeamOrKey(userID, userRole, accessToken);
|
||||
if (models) {
|
||||
setUserModels(models);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching user models:", error);
|
||||
|
@ -715,8 +707,8 @@ const Team: React.FC<TeamProps> = ({
|
|||
>
|
||||
<Text>
|
||||
{model.length > 30
|
||||
? `${model.slice(0, 30)}...`
|
||||
: model}
|
||||
? `${getModelDisplayName(model).slice(0, 30)}...`
|
||||
: getModelDisplayName(model)}
|
||||
</Text>
|
||||
</Badge>
|
||||
)
|
||||
|
@ -875,7 +867,7 @@ const Team: React.FC<TeamProps> = ({
|
|||
</Select2.Option>
|
||||
{userModels.map((model) => (
|
||||
<Select2.Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
|
|
|
@ -27,6 +27,7 @@ import {
|
|||
Textarea,
|
||||
} from "@tremor/react";
|
||||
import { InfoCircleOutlined } from '@ant-design/icons';
|
||||
import { fetchAvailableModelsForTeamOrKey, getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key";
|
||||
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
|
||||
import {
|
||||
Button as Button2,
|
||||
|
@ -121,7 +122,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
const [editModalVisible, setEditModalVisible] = useState(false);
|
||||
const [infoDialogVisible, setInfoDialogVisible] = useState(false);
|
||||
const [selectedToken, setSelectedToken] = useState<ItemData | null>(null);
|
||||
const [userModels, setUserModels] = useState([]);
|
||||
const [userModels, setUserModels] = useState<string[]>([]);
|
||||
const initialKnownTeamIDs: Set<string> = new Set();
|
||||
const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false);
|
||||
const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false);
|
||||
|
@ -238,17 +239,13 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
useEffect(() => {
|
||||
const fetchUserModels = async () => {
|
||||
try {
|
||||
if (userID === null) {
|
||||
if (userID === null || userRole === null || accessToken === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (accessToken !== null && userRole !== null) {
|
||||
const model_available = await modelAvailableCall(accessToken, userID, userRole);
|
||||
let available_model_names = model_available["data"].map(
|
||||
(element: { id: string }) => element.id
|
||||
);
|
||||
console.log("available_model_names:", available_model_names);
|
||||
setUserModels(available_model_names);
|
||||
const models = await fetchAvailableModelsForTeamOrKey(userID, userRole, accessToken);
|
||||
if (models) {
|
||||
setUserModels(models);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching user models:", error);
|
||||
|
@ -424,20 +421,20 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
keyTeam.models.includes("all-proxy-models") ? (
|
||||
userModels.filter(model => model !== "all-proxy-models").map((model: string) => (
|
||||
<Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Option>
|
||||
))
|
||||
) : (
|
||||
keyTeam.models.map((model: string) => (
|
||||
<Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Option>
|
||||
))
|
||||
)
|
||||
) : (
|
||||
userModels.map((model: string) => (
|
||||
<Option key={model} value={model}>
|
||||
{model}
|
||||
{getModelDisplayName(model)}
|
||||
</Option>
|
||||
))
|
||||
)}
|
||||
|
@ -1095,7 +1092,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
</Badge>
|
||||
) : (
|
||||
<Badge key={index} size={"xs"} className="mb-1" color="blue">
|
||||
<Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text>
|
||||
<Text>{model.length > 30 ? `${getModelDisplayName(model).slice(0, 30)}...` : getModelDisplayName(model)}</Text>
|
||||
</Badge>
|
||||
)
|
||||
))
|
||||
|
@ -1118,7 +1115,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
</Badge>
|
||||
) : (
|
||||
<Badge key={index} size={"xs"} className="mb-1" color="blue">
|
||||
<Text>{model.length > 30 ? `${model.slice(0, 30)}...` : model}</Text>
|
||||
<Text>{model.length > 30 ? `${getModelDisplayName(model).slice(0, 30)}...` : getModelDisplayName(model)}</Text>
|
||||
</Badge>
|
||||
)
|
||||
))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue