(round 4 fixes) - Team model alias setting (#8474)

* update team info endpoint

* clean up model alias

* fix model alias

* fix model alias card

* clean up naming on docs

* fix model alias card

* fix _model_in_team_aliases

* team alias - fix litellm.model_alias_map

* fix _update_model_if_team_alias_exists

* fix test_aview_spend_per_user

* Test model alias functionality with teams:

* complete e2e test

* test_update_model_if_team_alias_exists
This commit is contained in:
Ishaan Jaff 2025-02-11 16:40:01 -08:00 committed by GitHub
parent 51419338df
commit a449eb1dc9
4 changed files with 168 additions and 15 deletions

View file

@ -790,21 +790,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
raise Exception( raise Exception(
"Key is blocked. Update via `/key/unblock` if you're admin." "Key is blocked. Update via `/key/unblock` if you're admin."
) )
# Check 1. If token can call model
_model_alias_map = {}
model: Optional[str] = None
if (
hasattr(valid_token, "team_model_aliases")
and valid_token.team_model_aliases is not None
):
_model_alias_map = {
**valid_token.aliases,
**valid_token.team_model_aliases,
}
else:
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config config = valid_token.config
if config != {}: if config != {}:

View file

@ -635,6 +635,12 @@ async def add_litellm_data_to_request( # noqa: PLR0915
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
) )
# Team Model Aliases
_update_model_if_team_alias_exists(
data=data,
user_api_key_dict=user_api_key_dict,
)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"[PROXY] returned data from litellm_pre_call_utils: %s", data "[PROXY] returned data from litellm_pre_call_utils: %s", data
) )
@ -664,6 +670,32 @@ async def add_litellm_data_to_request( # noqa: PLR0915
return data return data
def _update_model_if_team_alias_exists(
data: dict,
user_api_key_dict: UserAPIKeyAuth,
) -> None:
"""
Update the model if the team alias exists
If a alias map has been set on a team, then we want to make the request with the model the team alias is pointing to
eg.
- user calls `gpt-4o`
- team.model_alias_map = {
"gpt-4o": "gpt-4o-team-1"
}
- requested_model = "gpt-4o-team-1"
"""
_model = data.get("model")
if (
_model
and user_api_key_dict.team_model_aliases
and _model in user_api_key_dict.team_model_aliases
):
data["model"] = user_api_key_dict.team_model_aliases[_model]
return
def _get_enforced_params( def _get_enforced_params(
general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
) -> Optional[list]: ) -> Optional[list]:

View file

@ -1617,3 +1617,53 @@ def test_provider_specific_header():
"anthropic-beta": "prompt-caching-2024-07-31", "anthropic-beta": "prompt-caching-2024-07-31",
}, },
} }
@pytest.mark.parametrize(
"data, user_api_key_dict, expected_model",
[
# Test case 1: Model exists in team aliases
(
{"model": "gpt-4o"},
UserAPIKeyAuth(
api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
),
"gpt-4o-team-1",
),
# Test case 2: Model doesn't exist in team aliases
(
{"model": "gpt-4o"},
UserAPIKeyAuth(
api_key="test_key", team_model_aliases={"claude-3": "claude-3-team-1"}
),
"gpt-4o",
),
# Test case 3: No team aliases defined
(
{"model": "gpt-4o"},
UserAPIKeyAuth(api_key="test_key", team_model_aliases=None),
"gpt-4o",
),
# Test case 4: No model in request data
(
{"messages": []},
UserAPIKeyAuth(
api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
),
None,
),
],
)
def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
# Make a copy of the input data to avoid modifying the test parameters
test_data = data.copy()
# Call the function
_update_model_if_team_alias_exists(
data=test_data, user_api_key_dict=user_api_key_dict
)
# Check if model was updated correctly
assert test_data.get("model") == expected_model

View file

@ -0,0 +1,86 @@
import pytest
import asyncio
import aiohttp
import json
from openai import AsyncOpenAI
import uuid
from httpx import AsyncClient
import uuid
import os
TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"
@pytest.mark.asyncio
async def test_team_model_alias():
"""
Test model alias functionality with teams:
1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o"
2. Create a new team
3. Update team with model_alias mapping
4. Generate key for team
5. Make request with aliased model name
"""
client = AsyncClient(base_url=PROXY_BASE_URL)
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
# Add new model
model_response = await client.post(
"/model/new",
json={
"model_name": "gpt-4o-team1",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
headers=headers,
)
assert model_response.status_code == 200
# Create new team
team_response = await client.post(
"/team/new",
json={
"models": ["gpt-4o-team1"],
},
headers=headers,
)
assert team_response.status_code == 200
team_data = team_response.json()
team_id = team_data["team_id"]
# Update team with model alias
update_response = await client.post(
"/team/update",
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
headers=headers,
)
assert update_response.status_code == 200
# Generate key for team
key_response = await client.post(
"/key/generate", json={"team_id": team_id}, headers=headers
)
assert key_response.status_code == 200
key = key_response.json()["key"]
# Make request with model alias
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
response = await openai_client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
)
assert response is not None, "Should get valid response when using model alias"
# Cleanup - delete the model
model_id = model_response.json()["model_info"]["id"]
delete_response = await client.post(
"/model/delete",
json={"id": model_id},
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
)
assert delete_response.status_code == 200