(UI) allow adding model aliases for teams (#8471)

* update team info endpoint

* clean up model alias

* fix model alias

* fix model alias card

* clean up naming on docs

* fix model alias card

* fix _model_in_team_aliases

* fix key_model_access_denied

* test_can_key_call_model_with_aliases

* fix test_aview_spend_per_user
This commit is contained in:
Ishaan Jaff 2025-02-11 16:18:43 -08:00 committed by GitHub
parent 89168d9113
commit 425f1b3976
6 changed files with 347 additions and 2 deletions

View file

@ -101,6 +101,7 @@ async def common_checks(
team_object=team_object, team_object=team_object,
model=_model, model=_model,
llm_router=llm_router, llm_router=llm_router,
team_model_aliases=valid_token.team_model_aliases if valid_token else None,
) )
## 2.1 If user can call model (if personal key) ## 2.1 If user can call model (if personal key)
@ -968,6 +969,7 @@ async def _can_object_call_model(
model: str, model: str,
llm_router: Optional[Router], llm_router: Optional[Router],
models: List[str], models: List[str],
team_model_aliases: Optional[Dict[str, str]] = None,
) -> Literal[True]: ) -> Literal[True]:
""" """
Checks if token can call a given model Checks if token can call a given model
@ -1002,6 +1004,9 @@ async def _can_object_call_model(
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
return True
if _model_matches_any_wildcard_pattern_in_list( if _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=filtered_models model=model, allowed_model_list=filtered_models
): ):
@ -1026,6 +1031,26 @@ async def _can_object_call_model(
return True return True
def _model_in_team_aliases(
model: str, team_model_aliases: Optional[Dict[str, str]] = None
) -> bool:
"""
Returns True if `model` being accessed is an alias of a team model
- `model=gpt-4o`
- `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
- returns True
- `model=gp-4o`
- `team_model_aliases={"o-3": "o3-preview"}`
- returns False
"""
if team_model_aliases:
if model in team_model_aliases:
return True
return False
async def can_key_call_model( async def can_key_call_model(
model: str, model: str,
llm_model_list: Optional[list], llm_model_list: Optional[list],
@ -1045,6 +1070,7 @@ async def can_key_call_model(
model=model, model=model,
llm_router=llm_router, llm_router=llm_router,
models=valid_token.models, models=valid_token.models,
team_model_aliases=valid_token.team_model_aliases,
) )
@ -1217,6 +1243,7 @@ def _team_model_access_check(
model: Optional[str], model: Optional[str],
team_object: Optional[LiteLLM_TeamTable], team_object: Optional[LiteLLM_TeamTable],
llm_router: Optional[Router], llm_router: Optional[Router],
team_model_aliases: Optional[Dict[str, str]] = None,
): ):
""" """
Access check for team models Access check for team models
@ -1244,6 +1271,8 @@ def _team_model_access_check(
pass pass
elif model and "*" in model: elif model and "*" in model:
pass pass
elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
pass
elif _model_matches_any_wildcard_pattern_in_list( elif _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=team_object.models model=model, allowed_model_list=team_object.models
): ):

View file

@ -1516,7 +1516,11 @@ async def list_team(
detail={"error": CommonProxyErrors.db_not_connected_error.value}, detail={"error": CommonProxyErrors.db_not_connected_error.value},
) )
response = await prisma_client.db.litellm_teamtable.find_many() response = await prisma_client.db.litellm_teamtable.find_many(
include={
"litellm_model_table": True,
}
)
filtered_response = [] filtered_response = []
if user_id: if user_id:

View file

@ -1022,3 +1022,82 @@ async def test_key_generate_always_db_team(mock_get_team_object):
mock_get_team_object.assert_called_once() mock_get_team_object.assert_called_once()
assert mock_get_team_object.call_args.kwargs["check_db_only"] == True assert mock_get_team_object.call_args.kwargs["check_db_only"] == True
@pytest.mark.asyncio
@pytest.mark.parametrize(
"requested_model, should_pass",
[
("gpt-4o", True), # Should pass - exact match in aliases
("gpt-4o-team1", True), # Should pass - team has access to this deployment
("gpt-4o-mini", False), # Should fail - not in aliases
("o-3", False), # Should fail - not in aliases
],
)
async def test_team_model_alias(prisma_client, requested_model, should_pass):
"""
Test team model alias functionality:
1. Create team with model alias = `{gpt-4o: gpt-4o-team1}`
2. Generate key for that team with model = `gpt-4o`
3. Verify chat completion request works with aliased model = `gpt-4o`
"""
litellm.set_verbose = True
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
# Create team with model alias
team_id = f"test_team_{uuid.uuid4()}"
await new_team(
data=NewTeamRequest(
team_id=team_id,
team_alias=f"test_team_alias_{uuid.uuid4()}",
models=["gpt-4o-team1"],
model_aliases={"gpt-4o": "gpt-4o-team1"},
),
http_request=Request(scope={"type": "http"}),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
),
)
# Generate key for the team
new_key = await generate_key_fn(
data=GenerateKeyRequest(
team_id=team_id,
models=["gpt-4o-team1"],
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
),
)
generated_key = new_key.key
# Test chat completion request
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return_string = f'{{"model": "{requested_model}"}}'
return return_string.encode()
request.body = return_body
if should_pass:
# Verify the key works with the aliased model
result = await user_api_key_auth(
request=request, api_key=f"Bearer {generated_key}"
)
assert result.models == [
"gpt-4o-team1"
], "Expected model list to contain aliased model"
assert result.team_model_aliases == {
"gpt-4o": "gpt-4o-team1"
}, "Expected model aliases to be present"
else:
# Verify the key fails with non-aliased models
with pytest.raises(Exception) as exc_info:
await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied

View file

@ -633,3 +633,52 @@ async def test_get_fuzzy_user_object():
mock_prisma.db.litellm_usertable.find_unique.assert_called_with( mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True} where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
) )
@pytest.mark.parametrize(
"model, alias_map, expect_to_work",
[
("gpt-4", {"gpt-4": "gpt-4-team1"}, True), # model matches alias value
("gpt-5", {"gpt-4": "gpt-4-team1"}, False),
],
)
@pytest.mark.asyncio
async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work):
"""
Test if can_key_call_model correctly handles model aliases in the token
"""
from litellm.proxy.auth.auth_checks import can_key_call_model
llm_model_list = [
{
"model_name": "gpt-4-team1",
"litellm_params": {
"model": "gpt-4",
"api_key": "test-api-key",
},
}
]
router = litellm.Router(model_list=llm_model_list)
user_api_key_object = UserAPIKeyAuth(
models=[
"gpt-4-team1",
],
team_model_aliases=alias_map,
)
if expect_to_work:
await can_key_call_model(
model=model,
llm_model_list=llm_model_list,
valid_token=user_api_key_object,
llm_router=router,
)
else:
with pytest.raises(Exception) as e:
await can_key_call_model(
model=model,
llm_model_list=llm_model_list,
valid_token=user_api_key_object,
llm_router=router,
)

View file

@ -0,0 +1,174 @@
import React, { useState } from "react";
import {
Card,
Title,
Text,
Button as TremorButton,
} from "@tremor/react";
import { Modal, Form, Select, Input, message } from "antd";
import { teamUpdateCall } from "@/components/networking";
interface ModelAliasesCardProps {
teamId: string;
accessToken: string | null;
currentAliases: Record<string, string>;
availableModels: string[];
onUpdate: () => void;
}
const ModelAliasesCard: React.FC<ModelAliasesCardProps> = ({
teamId,
accessToken,
currentAliases,
availableModels,
onUpdate,
}) => {
const [isModalVisible, setIsModalVisible] = useState(false);
const [form] = Form.useForm();
const handleCreateAlias = async (values: any) => {
try {
if (!accessToken) return;
const newAliases = {
...currentAliases,
[values.alias_name]: values.original_model,
};
const updateData = {
team_id: teamId,
model_aliases: newAliases,
};
await teamUpdateCall(accessToken, updateData);
message.success("Model alias created successfully");
setIsModalVisible(false);
form.resetFields();
currentAliases[values.alias_name] = values.original_model;
} catch (error) {
message.error("Failed to create model alias");
console.error("Error creating model alias:", error);
}
};
return (
<div className="mt-8">
<Title>Team Aliases</Title>
<Text className="text-gray-600 mb-4">
Allow a team to use an alias that points to a specific model deployment.
</Text>
<div className="bg-white rounded-lg p-6 border border-gray-200">
<div className="flex justify-between items-center mb-6">
<div>
<div className="flex space-x-4 text-gray-600">
<div className="w-64">ALIAS</div>
<div>POINTS TO</div>
</div>
</div>
<TremorButton
size="md"
variant="primary"
onClick={() => setIsModalVisible(true)}
>
Create Model Alias
</TremorButton>
</div>
<div className="space-y-4">
{Object.entries(currentAliases).map(([aliasName, originalModel], index) => (
<div key={index} className="flex space-x-4 border-t border-gray-200 pt-4">
<div className="w-64">
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
{aliasName}
</span>
</div>
<div>
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
{originalModel}
</span>
</div>
</div>
))}
{Object.keys(currentAliases).length === 0 && (
<div className="text-gray-500 text-center py-4 border-t border-gray-200">
No model aliases configured
</div>
)}
</div>
</div>
<Modal
title="Create Model Alias"
open={isModalVisible}
onCancel={() => {
setIsModalVisible(false);
form.resetFields();
}}
footer={null}
width={500}
>
<Form
form={form}
onFinish={handleCreateAlias}
layout="vertical"
className="mt-4"
>
<Form.Item
label="Alias Name"
name="alias_name"
rules={[{ required: true, message: "Please enter an alias name" }]}
>
<Input
placeholder="Enter the model alias (e.g., gpt-4o)"
type=""
/>
</Form.Item>
<Form.Item
label="Points To"
name="original_model"
rules={[{ required: true, message: "Please select a model" }]}
>
<Select
placeholder="Select model version"
className="w-full font-mono"
showSearch
optionFilterProp="children"
>
{availableModels.map((model) => (
<Select.Option key={model} value={model} className="font-mono">
{model}
</Select.Option>
))}
</Select>
</Form.Item>
<div className="flex justify-end gap-2 mt-6">
<TremorButton
size="md"
variant="secondary"
onClick={() => {
setIsModalVisible(false);
form.resetFields();
}}
className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50"
>
Cancel
</TremorButton>
<TremorButton
size="md"
variant="secondary"
type="submit"
>
Create Alias
</TremorButton>
</div>
</Form>
</Modal>
</div>
);
};
export default ModelAliasesCard;

View file

@ -29,6 +29,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
import TeamMemberModal from "./edit_membership"; import TeamMemberModal from "./edit_membership";
import UserSearchModal from "@/components/common_components/user_search_modal"; import UserSearchModal from "@/components/common_components/user_search_modal";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
import ModelAliasesCard from "./model_aliases_card";
interface TeamData { interface TeamData {
@ -586,7 +587,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
/> />
</Form.Item> </Form.Item>
<div className="flex justify-end gap-2"> <div className="flex justify-end gap-2 mt-6">
<Button onClick={() => setIsEditing(false)}> <Button onClick={() => setIsEditing(false)}>
Cancel Cancel
</Button> </Button>
@ -638,6 +639,15 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
</div> </div>
)} )}
</Card> </Card>
{/* Add Model Aliases Card */}
<ModelAliasesCard
teamId={teamId}
accessToken={accessToken}
currentAliases={teamData?.team_info?.litellm_model_table?.model_aliases || {}}
availableModels={userModels}
onUpdate={fetchTeamInfo}
/>
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</TabGroup> </TabGroup>