Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
(round 4 fixes) - Team model alias setting (#8474)
* update team info endpoint
* clean up model alias
* fix model alias
* fix model alias card
* clean up naming on docs
* fix model alias card
* fix _model_in_team_aliases
* team alias - fix litellm.model_alias_map
* fix _update_model_if_team_alias_exists
* fix test_aview_spend_per_user
* Test model alias functionality with teams:
* complete e2e test
* test_update_model_if_team_alias_exists
This commit is contained in:
parent 51419338df
commit a449eb1dc9
4 changed files with 168 additions and 15 deletions
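
For context on what the commit wires up: a team can carry a model-alias map (for example {"gpt-4o": "gpt-4o-team1"}), and the proxy rewrites the model named in an incoming request to the team-specific deployment before routing. A minimal conceptual sketch of that resolution rule, written independently of litellm's types (function and variable names here are illustrative only):

def resolve_team_model(requested_model: str, team_aliases: dict | None) -> str:
    """Return the team-specific model if an alias is defined, else the requested model."""
    if team_aliases and requested_model in team_aliases:
        return team_aliases[requested_model]
    return requested_model

# A team that maps the public model name onto its own deployment:
assert resolve_team_model("gpt-4o", {"gpt-4o": "gpt-4o-team1"}) == "gpt-4o-team1"
# Models without an alias pass through unchanged:
assert resolve_team_model("claude-3", {"gpt-4o": "gpt-4o-team1"}) == "claude-3"

The diffs below move exactly this lookup out of key authentication and into per-request pre-call handling.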
@@ -790,21 +790,6 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
             raise Exception(
                 "Key is blocked. Update via `/key/unblock` if you're admin."
             )
-
-        # Check 1. If token can call model
-        _model_alias_map = {}
-        model: Optional[str] = None
-        if (
-            hasattr(valid_token, "team_model_aliases")
-            and valid_token.team_model_aliases is not None
-        ):
-            _model_alias_map = {
-                **valid_token.aliases,
-                **valid_token.team_model_aliases,
-            }
-        else:
-            _model_alias_map = {**valid_token.aliases}
-        litellm.model_alias_map = _model_alias_map
         config = valid_token.config
 
         if config != {}:
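A plausible motivation for this removal (my inference; the commit message only says "team alias - fix litellm.model_alias_map"): litellm.model_alias_map is module-level state, so writing one key's merged aliases into it during auth lets concurrent requests from different teams overwrite each other's mappings. A minimal, self-contained sketch of that hazard, with a plain global dict standing in for litellm.model_alias_map and all names made up:

import asyncio

GLOBAL_ALIAS_MAP: dict = {}  # stand-in for a module-level alias map

async def handle_request(team_aliases: dict, requested_model: str) -> str:
    global GLOBAL_ALIAS_MAP
    GLOBAL_ALIAS_MAP = team_aliases  # request A installs its team's aliases...
    await asyncio.sleep(0)           # ...yields to the event loop...
    # ...and may read whatever map the last-scheduled request installed
    return GLOBAL_ALIAS_MAP.get(requested_model, requested_model)

async def main() -> None:
    results = await asyncio.gather(
        handle_request({"gpt-4o": "gpt-4o-team-1"}, "gpt-4o"),  # expects "gpt-4o-team-1"
        handle_request({}, "gpt-4o"),                           # a team with no aliases
    )
    print(results)  # the first request can lose its alias and resolve to plain "gpt-4o"

asyncio.run(main())

The replacement added in the next hunks resolves aliases against the per-request UserAPIKeyAuth object instead, so no global state is mutated.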
@@ -635,6 +635,12 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         user_api_key_dict=user_api_key_dict,
     )
+
+    # Team Model Aliases
+    _update_model_if_team_alias_exists(
+        data=data,
+        user_api_key_dict=user_api_key_dict,
+    )
 
     verbose_proxy_logger.debug(
         "[PROXY] returned data from litellm_pre_call_utils: %s", data
     )
@@ -664,6 +670,32 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
     return data
 
 
+def _update_model_if_team_alias_exists(
+    data: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+) -> None:
+    """
+    Update the model if a team alias exists.
+
+    If an alias map has been set on the team, make the request with the model
+    the team alias points to.
+
+    eg.
+    - user calls `gpt-4o`
+    - team.model_alias_map = {
+          "gpt-4o": "gpt-4o-team-1"
+      }
+    - requested_model = "gpt-4o-team-1"
+    """
+    _model = data.get("model")
+    if (
+        _model
+        and user_api_key_dict.team_model_aliases
+        and _model in user_api_key_dict.team_model_aliases
+    ):
+        data["model"] = user_api_key_dict.team_model_aliases[_model]
+    return
+
+
 def _get_enforced_params(
     general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
 ) -> Optional[list]:
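A quick usage sketch of the new helper, mirroring the unit test added below. The UserAPIKeyAuth import path is an assumption on my part; the constructor fields and the helper's signature follow this commit:

from litellm.proxy._types import UserAPIKeyAuth  # import path assumed, not shown in this diff
from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists

data = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}
auth = UserAPIKeyAuth(api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"})

# The helper rewrites the request body in place: "gpt-4o" becomes "gpt-4o-team-1".
_update_model_if_team_alias_exists(data=data, user_api_key_dict=auth)
print(data["model"])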
@@ -1617,3 +1617,53 @@ def test_provider_specific_header():
             "anthropic-beta": "prompt-caching-2024-07-31",
         },
     }
+
+
+@pytest.mark.parametrize(
+    "data, user_api_key_dict, expected_model",
+    [
+        # Test case 1: Model exists in team aliases
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
+            ),
+            "gpt-4o-team-1",
+        ),
+        # Test case 2: Model doesn't exist in team aliases
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"claude-3": "claude-3-team-1"}
+            ),
+            "gpt-4o",
+        ),
+        # Test case 3: No team aliases defined
+        (
+            {"model": "gpt-4o"},
+            UserAPIKeyAuth(api_key="test_key", team_model_aliases=None),
+            "gpt-4o",
+        ),
+        # Test case 4: No model in request data
+        (
+            {"messages": []},
+            UserAPIKeyAuth(
+                api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
+            ),
+            None,
+        ),
+    ],
+)
+def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
+    from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
+
+    # Make a copy of the input data to avoid modifying the test parameters
+    test_data = data.copy()
+
+    # Call the function
+    _update_model_if_team_alias_exists(
+        data=test_data, user_api_key_dict=user_api_key_dict
+    )
+
+    # Check if the model was updated correctly
+    assert test_data.get("model") == expected_model
tests/store_model_in_db_tests/test_team_alias.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import pytest
import asyncio
import aiohttp
import json
from openai import AsyncOpenAI
import uuid
from httpx import AsyncClient
import uuid
import os

TEST_MASTER_KEY = "sk-1234"
PROXY_BASE_URL = "http://0.0.0.0:4000"


@pytest.mark.asyncio
async def test_team_model_alias():
    """
    Test model alias functionality with teams:
    1. Add a new model with model_name="gpt-4o-team1" and litellm_params.model="gpt-4o"
    2. Create a new team
    3. Update the team with a model_alias mapping
    4. Generate a key for the team
    5. Make a request with the aliased model name
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Add new model
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4o-team1",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Create new team
    team_response = await client.post(
        "/team/new",
        json={
            "models": ["gpt-4o-team1"],
        },
        headers=headers,
    )
    assert team_response.status_code == 200
    team_data = team_response.json()
    team_id = team_data["team_id"]

    # Update team with model alias
    update_response = await client.post(
        "/team/update",
        json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
        headers=headers,
    )
    assert update_response.status_code == 200

    # Generate key for team
    key_response = await client.post(
        "/key/generate", json={"team_id": team_id}, headers=headers
    )
    assert key_response.status_code == 200
    key = key_response.json()["key"]

    # Make request with model alias
    openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")

    response = await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
    )

    assert response is not None, "Should get valid response when using model alias"

    # Cleanup - delete the model
    model_id = model_response.json()["model_info"]["id"]
    delete_response = await client.post(
        "/model/delete",
        json={"id": model_id},
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
    )
    assert delete_response.status_code == 200