mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(UI) allow adding model aliases for teams (#8471)
* update team info endpoint * clean up model alias * fix model alias * fix model alias card * clean up naming on docs * fix model alias card * fix _model_in_team_aliases * fix key_model_access_denied * test_can_key_call_model_with_aliases * fix test_aview_spend_per_user
This commit is contained in:
parent
89168d9113
commit
425f1b3976
6 changed files with 347 additions and 2 deletions
|
@ -101,6 +101,7 @@ async def common_checks(
|
|||
team_object=team_object,
|
||||
model=_model,
|
||||
llm_router=llm_router,
|
||||
team_model_aliases=valid_token.team_model_aliases if valid_token else None,
|
||||
)
|
||||
|
||||
## 2.1 If user can call model (if personal key)
|
||||
|
@ -968,6 +969,7 @@ async def _can_object_call_model(
|
|||
model: str,
|
||||
llm_router: Optional[Router],
|
||||
models: List[str],
|
||||
team_model_aliases: Optional[Dict[str, str]] = None,
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if token can call a given model
|
||||
|
@ -1002,6 +1004,9 @@ async def _can_object_call_model(
|
|||
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
if _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
|
||||
return True
|
||||
|
||||
if _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=filtered_models
|
||||
):
|
||||
|
@ -1026,6 +1031,26 @@ async def _can_object_call_model(
|
|||
return True
|
||||
|
||||
|
||||
def _model_in_team_aliases(
|
||||
model: str, team_model_aliases: Optional[Dict[str, str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if `model` being accessed is an alias of a team model
|
||||
|
||||
- `model=gpt-4o`
|
||||
- `team_model_aliases={"gpt-4o": "gpt-4o-team-1"}`
|
||||
- returns True
|
||||
|
||||
- `model=gp-4o`
|
||||
- `team_model_aliases={"o-3": "o3-preview"}`
|
||||
- returns False
|
||||
"""
|
||||
if team_model_aliases:
|
||||
if model in team_model_aliases:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def can_key_call_model(
|
||||
model: str,
|
||||
llm_model_list: Optional[list],
|
||||
|
@ -1045,6 +1070,7 @@ async def can_key_call_model(
|
|||
model=model,
|
||||
llm_router=llm_router,
|
||||
models=valid_token.models,
|
||||
team_model_aliases=valid_token.team_model_aliases,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1217,6 +1243,7 @@ def _team_model_access_check(
|
|||
model: Optional[str],
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
llm_router: Optional[Router],
|
||||
team_model_aliases: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Access check for team models
|
||||
|
@ -1244,6 +1271,8 @@ def _team_model_access_check(
|
|||
pass
|
||||
elif model and "*" in model:
|
||||
pass
|
||||
elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
|
||||
pass
|
||||
elif _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=team_object.models
|
||||
):
|
||||
|
|
|
@ -1516,7 +1516,11 @@ async def list_team(
|
|||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_teamtable.find_many()
|
||||
response = await prisma_client.db.litellm_teamtable.find_many(
|
||||
include={
|
||||
"litellm_model_table": True,
|
||||
}
|
||||
)
|
||||
|
||||
filtered_response = []
|
||||
if user_id:
|
||||
|
|
|
@ -1022,3 +1022,82 @@ async def test_key_generate_always_db_team(mock_get_team_object):
|
|||
|
||||
mock_get_team_object.assert_called_once()
|
||||
assert mock_get_team_object.call_args.kwargs["check_db_only"] == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"requested_model, should_pass",
|
||||
[
|
||||
("gpt-4o", True), # Should pass - exact match in aliases
|
||||
("gpt-4o-team1", True), # Should pass - team has access to this deployment
|
||||
("gpt-4o-mini", False), # Should fail - not in aliases
|
||||
("o-3", False), # Should fail - not in aliases
|
||||
],
|
||||
)
|
||||
async def test_team_model_alias(prisma_client, requested_model, should_pass):
|
||||
"""
|
||||
Test team model alias functionality:
|
||||
1. Create team with model alias = `{gpt-4o: gpt-4o-team1}`
|
||||
2. Generate key for that team with model = `gpt-4o`
|
||||
3. Verify chat completion request works with aliased model = `gpt-4o`
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
# Create team with model alias
|
||||
team_id = f"test_team_{uuid.uuid4()}"
|
||||
await new_team(
|
||||
data=NewTeamRequest(
|
||||
team_id=team_id,
|
||||
team_alias=f"test_team_alias_{uuid.uuid4()}",
|
||||
models=["gpt-4o-team1"],
|
||||
model_aliases={"gpt-4o": "gpt-4o-team1"},
|
||||
),
|
||||
http_request=Request(scope={"type": "http"}),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
|
||||
),
|
||||
)
|
||||
|
||||
# Generate key for the team
|
||||
new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
team_id=team_id,
|
||||
models=["gpt-4o-team1"],
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
|
||||
),
|
||||
)
|
||||
|
||||
generated_key = new_key.key
|
||||
|
||||
# Test chat completion request
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
async def return_body():
|
||||
return_string = f'{{"model": "{requested_model}"}}'
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body
|
||||
|
||||
if should_pass:
|
||||
# Verify the key works with the aliased model
|
||||
result = await user_api_key_auth(
|
||||
request=request, api_key=f"Bearer {generated_key}"
|
||||
)
|
||||
|
||||
assert result.models == [
|
||||
"gpt-4o-team1"
|
||||
], "Expected model list to contain aliased model"
|
||||
assert result.team_model_aliases == {
|
||||
"gpt-4o": "gpt-4o-team1"
|
||||
}, "Expected model aliases to be present"
|
||||
else:
|
||||
# Verify the key fails with non-aliased models
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
|
||||
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
|
||||
|
|
|
@ -633,3 +633,52 @@ async def test_get_fuzzy_user_object():
|
|||
mock_prisma.db.litellm_usertable.find_unique.assert_called_with(
|
||||
where={"sso_user_id": "sso_123"}, include={"organization_memberships": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, alias_map, expect_to_work",
|
||||
[
|
||||
("gpt-4", {"gpt-4": "gpt-4-team1"}, True), # model matches alias value
|
||||
("gpt-5", {"gpt-4": "gpt-4-team1"}, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_key_call_model_with_aliases(model, alias_map, expect_to_work):
|
||||
"""
|
||||
Test if can_key_call_model correctly handles model aliases in the token
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import can_key_call_model
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "gpt-4-team1",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
|
||||
user_api_key_object = UserAPIKeyAuth(
|
||||
models=[
|
||||
"gpt-4-team1",
|
||||
],
|
||||
team_model_aliases=alias_map,
|
||||
)
|
||||
|
||||
if expect_to_work:
|
||||
await can_key_call_model(
|
||||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=user_api_key_object,
|
||||
llm_router=router,
|
||||
)
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
await can_key_call_model(
|
||||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=user_api_key_object,
|
||||
llm_router=router,
|
||||
)
|
||||
|
|
174
ui/litellm-dashboard/src/components/team/model_aliases_card.tsx
Normal file
174
ui/litellm-dashboard/src/components/team/model_aliases_card.tsx
Normal file
|
@ -0,0 +1,174 @@
|
|||
import React, { useState } from "react";
|
||||
import {
|
||||
Card,
|
||||
Title,
|
||||
Text,
|
||||
Button as TremorButton,
|
||||
} from "@tremor/react";
|
||||
import { Modal, Form, Select, Input, message } from "antd";
|
||||
import { teamUpdateCall } from "@/components/networking";
|
||||
|
||||
interface ModelAliasesCardProps {
|
||||
teamId: string;
|
||||
accessToken: string | null;
|
||||
currentAliases: Record<string, string>;
|
||||
availableModels: string[];
|
||||
onUpdate: () => void;
|
||||
}
|
||||
|
||||
const ModelAliasesCard: React.FC<ModelAliasesCardProps> = ({
|
||||
teamId,
|
||||
accessToken,
|
||||
currentAliases,
|
||||
availableModels,
|
||||
onUpdate,
|
||||
}) => {
|
||||
const [isModalVisible, setIsModalVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
|
||||
const handleCreateAlias = async (values: any) => {
|
||||
try {
|
||||
if (!accessToken) return;
|
||||
|
||||
const newAliases = {
|
||||
...currentAliases,
|
||||
[values.alias_name]: values.original_model,
|
||||
};
|
||||
|
||||
const updateData = {
|
||||
team_id: teamId,
|
||||
model_aliases: newAliases,
|
||||
};
|
||||
|
||||
await teamUpdateCall(accessToken, updateData);
|
||||
message.success("Model alias created successfully");
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
currentAliases[values.alias_name] = values.original_model;
|
||||
} catch (error) {
|
||||
message.error("Failed to create model alias");
|
||||
console.error("Error creating model alias:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-8">
|
||||
<Title>Team Aliases</Title>
|
||||
<Text className="text-gray-600 mb-4">
|
||||
Allow a team to use an alias that points to a specific model deployment.
|
||||
|
||||
</Text>
|
||||
|
||||
<div className="bg-white rounded-lg p-6 border border-gray-200">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<div>
|
||||
<div className="flex space-x-4 text-gray-600">
|
||||
<div className="w-64">ALIAS</div>
|
||||
<div>POINTS TO</div>
|
||||
</div>
|
||||
</div>
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="primary"
|
||||
onClick={() => setIsModalVisible(true)}
|
||||
>
|
||||
Create Model Alias
|
||||
</TremorButton>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
{Object.entries(currentAliases).map(([aliasName, originalModel], index) => (
|
||||
<div key={index} className="flex space-x-4 border-t border-gray-200 pt-4">
|
||||
<div className="w-64">
|
||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||
{aliasName}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||
{originalModel}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{Object.keys(currentAliases).length === 0 && (
|
||||
<div className="text-gray-500 text-center py-4 border-t border-gray-200">
|
||||
No model aliases configured
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Modal
|
||||
title="Create Model Alias"
|
||||
open={isModalVisible}
|
||||
onCancel={() => {
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
footer={null}
|
||||
width={500}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleCreateAlias}
|
||||
layout="vertical"
|
||||
className="mt-4"
|
||||
>
|
||||
<Form.Item
|
||||
label="Alias Name"
|
||||
name="alias_name"
|
||||
rules={[{ required: true, message: "Please enter an alias name" }]}
|
||||
>
|
||||
<Input
|
||||
placeholder="Enter the model alias (e.g., gpt-4o)"
|
||||
type=""
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="Points To"
|
||||
name="original_model"
|
||||
rules={[{ required: true, message: "Please select a model" }]}
|
||||
>
|
||||
<Select
|
||||
placeholder="Select model version"
|
||||
className="w-full font-mono"
|
||||
showSearch
|
||||
optionFilterProp="children"
|
||||
>
|
||||
{availableModels.map((model) => (
|
||||
<Select.Option key={model} value={model} className="font-mono">
|
||||
{model}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<div className="flex justify-end gap-2 mt-6">
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="secondary"
|
||||
onClick={() => {
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50"
|
||||
>
|
||||
Cancel
|
||||
</TremorButton>
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="secondary"
|
||||
type="submit"
|
||||
>
|
||||
Create Alias
|
||||
</TremorButton>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ModelAliasesCard;
|
|
@ -29,6 +29,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
|
|||
import TeamMemberModal from "./edit_membership";
|
||||
import UserSearchModal from "@/components/common_components/user_search_modal";
|
||||
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
|
||||
import ModelAliasesCard from "./model_aliases_card";
|
||||
|
||||
|
||||
interface TeamData {
|
||||
|
@ -586,7 +587,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
|||
/>
|
||||
</Form.Item>
|
||||
|
||||
<div className="flex justify-end gap-2">
|
||||
<div className="flex justify-end gap-2 mt-6">
|
||||
<Button onClick={() => setIsEditing(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
|
@ -638,6 +639,15 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
|||
</div>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
{/* Add Model Aliases Card */}
|
||||
<ModelAliasesCard
|
||||
teamId={teamId}
|
||||
accessToken={accessToken}
|
||||
currentAliases={teamData?.team_info?.litellm_model_table?.model_aliases || {}}
|
||||
availableModels={userModels}
|
||||
onUpdate={fetchTeamInfo}
|
||||
/>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</TabGroup>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue