mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
(round 4 fixes) - Team model alias setting (#8474)
* update team info endpoint * clean up model alias * fix model alias * fix model alias card * clean up naming on docs * fix model alias card * fix _model_in_team_aliases * team alias - fix litellm.model_alias_map * fix _update_model_if_team_alias_exists * fix test_aview_spend_per_user * Test model alias functionality with teams: * complete e2e test * test_update_model_if_team_alias_exists
This commit is contained in:
parent
51419338df
commit
a449eb1dc9
4 changed files with 168 additions and 15 deletions
|
@ -790,21 +790,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Key is blocked. Update via `/key/unblock` if you're admin."
|
"Key is blocked. Update via `/key/unblock` if you're admin."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check 1. If token can call model
|
|
||||||
_model_alias_map = {}
|
|
||||||
model: Optional[str] = None
|
|
||||||
if (
|
|
||||||
hasattr(valid_token, "team_model_aliases")
|
|
||||||
and valid_token.team_model_aliases is not None
|
|
||||||
):
|
|
||||||
_model_alias_map = {
|
|
||||||
**valid_token.aliases,
|
|
||||||
**valid_token.team_model_aliases,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
_model_alias_map = {**valid_token.aliases}
|
|
||||||
litellm.model_alias_map = _model_alias_map
|
|
||||||
config = valid_token.config
|
config = valid_token.config
|
||||||
|
|
||||||
if config != {}:
|
if config != {}:
|
||||||
|
|
|
@ -635,6 +635,12 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Team Model Aliases
|
||||||
|
_update_model_if_team_alias_exists(
|
||||||
|
data=data,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"[PROXY] returned data from litellm_pre_call_utils: %s", data
|
"[PROXY] returned data from litellm_pre_call_utils: %s", data
|
||||||
)
|
)
|
||||||
|
@ -664,6 +670,32 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_if_team_alias_exists(
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update the model if the team alias exists
|
||||||
|
|
||||||
|
If a alias map has been set on a team, then we want to make the request with the model the team alias is pointing to
|
||||||
|
|
||||||
|
eg.
|
||||||
|
- user calls `gpt-4o`
|
||||||
|
- team.model_alias_map = {
|
||||||
|
"gpt-4o": "gpt-4o-team-1"
|
||||||
|
}
|
||||||
|
- requested_model = "gpt-4o-team-1"
|
||||||
|
"""
|
||||||
|
_model = data.get("model")
|
||||||
|
if (
|
||||||
|
_model
|
||||||
|
and user_api_key_dict.team_model_aliases
|
||||||
|
and _model in user_api_key_dict.team_model_aliases
|
||||||
|
):
|
||||||
|
data["model"] = user_api_key_dict.team_model_aliases[_model]
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def _get_enforced_params(
|
def _get_enforced_params(
|
||||||
general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
|
general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
|
||||||
) -> Optional[list]:
|
) -> Optional[list]:
|
||||||
|
|
|
@ -1617,3 +1617,53 @@ def test_provider_specific_header():
|
||||||
"anthropic-beta": "prompt-caching-2024-07-31",
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data, user_api_key_dict, expected_model",
|
||||||
|
[
|
||||||
|
# Test case 1: Model exists in team aliases
|
||||||
|
(
|
||||||
|
{"model": "gpt-4o"},
|
||||||
|
UserAPIKeyAuth(
|
||||||
|
api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
|
||||||
|
),
|
||||||
|
"gpt-4o-team-1",
|
||||||
|
),
|
||||||
|
# Test case 2: Model doesn't exist in team aliases
|
||||||
|
(
|
||||||
|
{"model": "gpt-4o"},
|
||||||
|
UserAPIKeyAuth(
|
||||||
|
api_key="test_key", team_model_aliases={"claude-3": "claude-3-team-1"}
|
||||||
|
),
|
||||||
|
"gpt-4o",
|
||||||
|
),
|
||||||
|
# Test case 3: No team aliases defined
|
||||||
|
(
|
||||||
|
{"model": "gpt-4o"},
|
||||||
|
UserAPIKeyAuth(api_key="test_key", team_model_aliases=None),
|
||||||
|
"gpt-4o",
|
||||||
|
),
|
||||||
|
# Test case 4: No model in request data
|
||||||
|
(
|
||||||
|
{"messages": []},
|
||||||
|
UserAPIKeyAuth(
|
||||||
|
api_key="test_key", team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):
|
||||||
|
from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists
|
||||||
|
|
||||||
|
# Make a copy of the input data to avoid modifying the test parameters
|
||||||
|
test_data = data.copy()
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
_update_model_if_team_alias_exists(
|
||||||
|
data=test_data, user_api_key_dict=user_api_key_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if model was updated correctly
|
||||||
|
assert test_data.get("model") == expected_model
|
||||||
|
|
86
tests/store_model_in_db_tests/test_team_alias.py
Normal file
86
tests/store_model_in_db_tests/test_team_alias.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import uuid
|
||||||
|
from httpx import AsyncClient
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
|
||||||
|
TEST_MASTER_KEY = "sk-1234"
|
||||||
|
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_alias():
|
||||||
|
"""
|
||||||
|
Test model alias functionality with teams:
|
||||||
|
1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o"
|
||||||
|
2. Create a new team
|
||||||
|
3. Update team with model_alias mapping
|
||||||
|
4. Generate key for team
|
||||||
|
5. Make request with aliased model name
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||||
|
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||||
|
|
||||||
|
# Add new model
|
||||||
|
model_response = await client.post(
|
||||||
|
"/model/new",
|
||||||
|
json={
|
||||||
|
"model_name": "gpt-4o-team1",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert model_response.status_code == 200
|
||||||
|
|
||||||
|
# Create new team
|
||||||
|
team_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={
|
||||||
|
"models": ["gpt-4o-team1"],
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team_response.status_code == 200
|
||||||
|
team_data = team_response.json()
|
||||||
|
team_id = team_data["team_id"]
|
||||||
|
|
||||||
|
# Update team with model alias
|
||||||
|
update_response = await client.post(
|
||||||
|
"/team/update",
|
||||||
|
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert update_response.status_code == 200
|
||||||
|
|
||||||
|
# Generate key for team
|
||||||
|
key_response = await client.post(
|
||||||
|
"/key/generate", json={"team_id": team_id}, headers=headers
|
||||||
|
)
|
||||||
|
assert key_response.status_code == 200
|
||||||
|
key = key_response.json()["key"]
|
||||||
|
|
||||||
|
# Make request with model alias
|
||||||
|
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
|
||||||
|
|
||||||
|
response = await openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None, "Should get valid response when using model alias"
|
||||||
|
|
||||||
|
# Cleanup - delete the model
|
||||||
|
model_id = model_response.json()["model_info"]["id"]
|
||||||
|
delete_response = await client.post(
|
||||||
|
"/model/delete",
|
||||||
|
json={"id": model_id},
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
)
|
||||||
|
assert delete_response.status_code == 200
|
Loading…
Add table
Add a link
Reference in a new issue