mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(UI) allow adding model aliases for teams (#8471)
* update team info endpoint * clean up model alias * fix model alias * fix model alias card * clean up naming on docs * fix model alias card * fix _model_in_team_aliases * fix key_model_access_denied * test_can_key_call_model_with_aliases * fix test_aview_spend_per_user
This commit is contained in:
parent
89168d9113
commit
425f1b3976
6 changed files with 347 additions and 2 deletions
|
@ -101,6 +101,7 @@ async def common_checks(
|
||||||
team_object=team_object,
|
team_object=team_object,
|
||||||
model=_model,
|
model=_model,
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
|
team_model_aliases=valid_token.team_model_aliases if valid_token else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
## 2.1 If user can call model (if personal key)
|
## 2.1 If user can call model (if personal key)
|
||||||
|
@ -968,6 +969,7 @@ async def _can_object_call_model(
|
||||||
model: str,
|
model: str,
|
||||||
llm_router: Optional[Router],
|
llm_router: Optional[Router],
|
||||||
models: List[str],
|
models: List[str],
|
||||||
|
team_model_aliases: Optional[Dict[str, str]] = None,
|
||||||
) -> Literal[True]:
|
) -> Literal[True]:
|
||||||
"""
|
"""
|
||||||
Checks if token can call a given model
|
Checks if token can call a given model
|
||||||
|
@ -1002,6 +1004,9 @@ async def _can_object_call_model(
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||||
|
|
||||||
|
if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
|
||||||
|
return True
|
||||||
|
|
||||||
if _model_matches_any_wildcard_pattern_in_list(
|
if _model_matches_any_wildcard_pattern_in_list(
|
||||||
model=model, allowed_model_list=filtered_models
|
model=model, allowed_model_list=filtered_models
|
||||||
):
|
):
|
||||||
|
@ -1026,6 +1031,26 @@ async def _can_object_call_model(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _model_in_team_aliases(
|
||||||
|
model: str, team_model_aliases: Optional[Dict[str, str]] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if `model` being accessed is an alias of a team model
|
||||||
|
|
||||||
|
- `model=gpt-4o`
|
||||||
|
- `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
|
||||||
|
- returns True
|
||||||
|
|
||||||
|
- `model=gp-4o`
|
||||||
|
- `team_model_aliases={"o-3": "o3-preview"}`
|
||||||
|
- returns False
|
||||||
|
"""
|
||||||
|
if team_model_aliases:
|
||||||
|
if model in team_model_aliases:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def can_key_call_model(
|
async def can_key_call_model(
|
||||||
model: str,
|
model: str,
|
||||||
llm_model_list: Optional[list],
|
llm_model_list: Optional[list],
|
||||||
|
@ -1045,6 +1070,7 @@ async def can_key_call_model(
|
||||||
model=model,
|
model=model,
|
||||||
llm_router=llm_router,
|
llm_router=llm_router,
|
||||||
models=valid_token.models,
|
models=valid_token.models,
|
||||||
|
team_model_aliases=valid_token.team_model_aliases,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1217,6 +1243,7 @@ def _team_model_access_check(
|
||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
team_object: Optional[LiteLLM_TeamTable],
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
llm_router: Optional[Router],
|
llm_router: Optional[Router],
|
||||||
|
team_model_aliases: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Access check for team models
|
Access check for team models
|
||||||
|
@ -1244,6 +1271,8 @@ def _team_model_access_check(
|
||||||
pass
|
pass
|
||||||
elif model and "*" in model:
|
elif model and "*" in model:
|
||||||
pass
|
pass
|
||||||
|
elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
|
||||||
|
pass
|
||||||
elif _model_matches_any_wildcard_pattern_in_list(
|
elif _model_matches_any_wildcard_pattern_in_list(
|
||||||
model=model, allowed_model_list=team_object.models
|
model=model, allowed_model_list=team_object.models
|
||||||
):
|
):
|
||||||
|
|
|
@ -1516,7 +1516,11 @@ async def list_team(
|
||||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await prisma_client.db.litellm_teamtable.find_many()
|
response = await prisma_client.db.litellm_teamtable.find_many(
|
||||||
|
include={
|
||||||
|
"litellm_model_table": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
filtered_response = []
|
filtered_response = []
|
||||||
if user_id:
|
if user_id:
|
||||||
|
|
|
@ -1022,3 +1022,82 @@ async def test_key_generate_always_db_team(mock_get_team_object):
|
||||||
|
|
||||||
mock_get_team_object.assert_called_once()
|
mock_get_team_object.assert_called_once()
|
||||||
assert mock_get_team_object.call_args.kwargs["check_db_only"] == True
|
assert mock_get_team_object.call_args.kwargs["check_db_only"] == True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"requested_model, should_pass",
|
||||||
|
[
|
||||||
|
("gpt-4o", True), # Should pass - exact match in aliases
|
||||||
|
("gpt-4o-team1", True), # Should pass - team has access to this deployment
|
||||||
|
("gpt-4o-mini", False), # Should fail - not in aliases
|
||||||
|
("o-3", False), # Should fail - not in aliases
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_team_model_alias(prisma_client, requested_model, should_pass):
|
||||||
|
"""
|
||||||
|
Test team model alias functionality:
|
||||||
|
1. Create team with model alias = `{gpt-4o: gpt-4o-team1}`
|
||||||
|
2. Generate key for that team with model = `gpt-4o`
|
||||||
|
3. Verify chat completion request works with aliased model = `gpt-4o`
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
# Create team with model alias
|
||||||
|
team_id = f"test_team_{uuid.uuid4()}"
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=team_id,
|
||||||
|
team_alias=f"test_team_alias_{uuid.uuid4()}",
|
||||||
|
models=["gpt-4o-team1"],
|
||||||
|
model_aliases={"gpt-4o": "gpt-4o-team1"},
|
||||||
|
),
|
||||||
|
http_request=Request(scope={"type": "http"}),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate key for the team
|
||||||
|
new_key = await generate_key_fn(
|
||||||
|
data=GenerateKeyRequest(
|
||||||
|
team_id=team_id,
|
||||||
|
models=["gpt-4o-team1"],
|
||||||
|
),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_key = new_key.key
|
||||||
|
|
||||||
|
# Test chat completion request
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
async def return_body():
|
||||||
|
return_string = f'{{"model": "{requested_model}"}}'
|
||||||
|
return return_string.encode()
|
||||||
|
|
||||||
|
request.body = return_body
|
||||||
|
|
||||||
|
if should_pass:
|
||||||
|
# Verify the key works with the aliased model
|
||||||
|
result = await user_api_key_auth(
|
||||||
|
request=request, api_key=f"Bearer {generated_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.models == [
|
||||||
|
"gpt-4o-team1"
|
||||||
|
], "Expected model list to contain aliased model"
|
||||||
|
assert result.team_model_aliases == {
|
||||||
|
"gpt-4o": "gpt-4o-team1"
|
||||||
|
}, "Expected model aliases to be present"
|
||||||
|
else:
|
||||||
|
# Verify the key fails with non-aliased models
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
|
||||||
|
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
|
||||||
|
|
|
@ -633,3 +633,52 @@ async def test_get_fuzzy_user_object():
|
||||||
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||||
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, alias_map, expect_to_work",
|
||||||
|
[
|
||||||
|
("gpt-4", {"gpt-4": "gpt-4-team1"}, True), # model matches alias value
|
||||||
|
("gpt-5", {"gpt-4": "gpt-4-team1"}, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work):
|
||||||
|
"""
|
||||||
|
Test if can_key_call_model correctly handles model aliases in the token
|
||||||
|
"""
|
||||||
|
from litellm.proxy.auth.auth_checks import can_key_call_model
|
||||||
|
|
||||||
|
llm_model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-4-team1",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
router = litellm.Router(model_list=llm_model_list)
|
||||||
|
|
||||||
|
user_api_key_object = UserAPIKeyAuth(
|
||||||
|
models=[
|
||||||
|
"gpt-4-team1",
|
||||||
|
],
|
||||||
|
team_model_aliases=alias_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
if expect_to_work:
|
||||||
|
await can_key_call_model(
|
||||||
|
model=model,
|
||||||
|
llm_model_list=llm_model_list,
|
||||||
|
valid_token=user_api_key_object,
|
||||||
|
llm_router=router,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
await can_key_call_model(
|
||||||
|
model=model,
|
||||||
|
llm_model_list=llm_model_list,
|
||||||
|
valid_token=user_api_key_object,
|
||||||
|
llm_router=router,
|
||||||
|
)
|
||||||
|
|
174
ui/litellm-dashboard/src/components/team/model_aliases_card.tsx
Normal file
174
ui/litellm-dashboard/src/components/team/model_aliases_card.tsx
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
import React, { useState } from "react";
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
Title,
|
||||||
|
Text,
|
||||||
|
Button as TremorButton,
|
||||||
|
} from "@tremor/react";
|
||||||
|
import { Modal, Form, Select, Input, message } from "antd";
|
||||||
|
import { teamUpdateCall } from "@/components/networking";
|
||||||
|
|
||||||
|
interface ModelAliasesCardProps {
|
||||||
|
teamId: string;
|
||||||
|
accessToken: string | null;
|
||||||
|
currentAliases: Record<string, string>;
|
||||||
|
availableModels: string[];
|
||||||
|
onUpdate: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ModelAliasesCard: React.FC<ModelAliasesCardProps> = ({
|
||||||
|
teamId,
|
||||||
|
accessToken,
|
||||||
|
currentAliases,
|
||||||
|
availableModels,
|
||||||
|
onUpdate,
|
||||||
|
}) => {
|
||||||
|
const [isModalVisible, setIsModalVisible] = useState(false);
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
|
||||||
|
const handleCreateAlias = async (values: any) => {
|
||||||
|
try {
|
||||||
|
if (!accessToken) return;
|
||||||
|
|
||||||
|
const newAliases = {
|
||||||
|
...currentAliases,
|
||||||
|
[values.alias_name]: values.original_model,
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateData = {
|
||||||
|
team_id: teamId,
|
||||||
|
model_aliases: newAliases,
|
||||||
|
};
|
||||||
|
|
||||||
|
await teamUpdateCall(accessToken, updateData);
|
||||||
|
message.success("Model alias created successfully");
|
||||||
|
setIsModalVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
currentAliases[values.alias_name] = values.original_model;
|
||||||
|
} catch (error) {
|
||||||
|
message.error("Failed to create model alias");
|
||||||
|
console.error("Error creating model alias:", error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mt-8">
|
||||||
|
<Title>Team Aliases</Title>
|
||||||
|
<Text className="text-gray-600 mb-4">
|
||||||
|
Allow a team to use an alias that points to a specific model deployment.
|
||||||
|
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="bg-white rounded-lg p-6 border border-gray-200">
|
||||||
|
<div className="flex justify-between items-center mb-6">
|
||||||
|
<div>
|
||||||
|
<div className="flex space-x-4 text-gray-600">
|
||||||
|
<div className="w-64">ALIAS</div>
|
||||||
|
<div>POINTS TO</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<TremorButton
|
||||||
|
size="md"
|
||||||
|
variant="primary"
|
||||||
|
onClick={() => setIsModalVisible(true)}
|
||||||
|
>
|
||||||
|
Create Model Alias
|
||||||
|
</TremorButton>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-4">
|
||||||
|
{Object.entries(currentAliases).map(([aliasName, originalModel], index) => (
|
||||||
|
<div key={index} className="flex space-x-4 border-t border-gray-200 pt-4">
|
||||||
|
<div className="w-64">
|
||||||
|
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||||
|
{aliasName}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||||
|
{originalModel}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
{Object.keys(currentAliases).length === 0 && (
|
||||||
|
<div className="text-gray-500 text-center py-4 border-t border-gray-200">
|
||||||
|
No model aliases configured
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Modal
|
||||||
|
title="Create Model Alias"
|
||||||
|
open={isModalVisible}
|
||||||
|
onCancel={() => {
|
||||||
|
setIsModalVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
}}
|
||||||
|
footer={null}
|
||||||
|
width={500}
|
||||||
|
>
|
||||||
|
<Form
|
||||||
|
form={form}
|
||||||
|
onFinish={handleCreateAlias}
|
||||||
|
layout="vertical"
|
||||||
|
className="mt-4"
|
||||||
|
>
|
||||||
|
<Form.Item
|
||||||
|
label="Alias Name"
|
||||||
|
name="alias_name"
|
||||||
|
rules={[{ required: true, message: "Please enter an alias name" }]}
|
||||||
|
>
|
||||||
|
<Input
|
||||||
|
placeholder="Enter the model alias (e.g., gpt-4o)"
|
||||||
|
type=""
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label="Points To"
|
||||||
|
name="original_model"
|
||||||
|
rules={[{ required: true, message: "Please select a model" }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="Select model version"
|
||||||
|
className="w-full font-mono"
|
||||||
|
showSearch
|
||||||
|
optionFilterProp="children"
|
||||||
|
>
|
||||||
|
{availableModels.map((model) => (
|
||||||
|
<Select.Option key={model} value={model} className="font-mono">
|
||||||
|
{model}
|
||||||
|
</Select.Option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<div className="flex justify-end gap-2 mt-6">
|
||||||
|
<TremorButton
|
||||||
|
size="md"
|
||||||
|
variant="secondary"
|
||||||
|
onClick={() => {
|
||||||
|
setIsModalVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
}}
|
||||||
|
className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</TremorButton>
|
||||||
|
<TremorButton
|
||||||
|
size="md"
|
||||||
|
variant="secondary"
|
||||||
|
type="submit"
|
||||||
|
>
|
||||||
|
Create Alias
|
||||||
|
</TremorButton>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ModelAliasesCard;
|
|
@ -29,6 +29,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
|
||||||
import TeamMemberModal from "./edit_membership";
|
import TeamMemberModal from "./edit_membership";
|
||||||
import UserSearchModal from "@/components/common_components/user_search_modal";
|
import UserSearchModal from "@/components/common_components/user_search_modal";
|
||||||
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
|
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
|
||||||
|
import ModelAliasesCard from "./model_aliases_card";
|
||||||
|
|
||||||
|
|
||||||
interface TeamData {
|
interface TeamData {
|
||||||
|
@ -586,7 +587,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
<div className="flex justify-end gap-2">
|
<div className="flex justify-end gap-2 mt-6">
|
||||||
<Button onClick={() => setIsEditing(false)}>
|
<Button onClick={() => setIsEditing(false)}>
|
||||||
Cancel
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
|
@ -638,6 +639,15 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
|
{/* Add Model Aliases Card */}
|
||||||
|
<ModelAliasesCard
|
||||||
|
teamId={teamId}
|
||||||
|
accessToken={accessToken}
|
||||||
|
currentAliases={teamData?.team_info?.litellm_model_table?.model_aliases || {}}
|
||||||
|
availableModels={userModels}
|
||||||
|
onUpdate={fetchTeamInfo}
|
||||||
|
/>
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</TabGroup>
|
</TabGroup>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue