fix team based tag routing

This commit is contained in:
Ishaan Jaff 2024-08-29 14:37:44 -07:00
parent 308377fbe2
commit da2cefc45a
3 changed files with 22 additions and 42 deletions

View file

@ -1,20 +1,20 @@
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamA"]
model_info:
id: "teama"
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamB"]
model_info:
id: "teamb"
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamA"] # 👈 Key Change
model_info:
id: "team-a-model" # used for identifying model in response headers
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
tags: ["teamB"] # 👈 Key Change
model_info:
id: "team-b-model" # used for identifying model in response headers
- model_name: rerank-english-v3.0
litellm_params:
model: cohere/rerank-english-v3.0

View file

@ -379,12 +379,6 @@ async def add_litellm_data_to_request(
# unpack callback_vars in data
for k, v in callback_settings_obj.callback_vars.items():
data[k] = v
# Team based tags
add_team_based_tags_to_metadata(
data=data,
_metadata_variable_name=_metadata_variable_name,
user_api_key_dict=user_api_key_dict,
)
# Guardrails
move_guardrails_to_metadata(
@ -396,24 +390,6 @@ async def add_litellm_data_to_request(
return data
def add_team_based_tags_to_metadata(
data: dict,
_metadata_variable_name: str,
user_api_key_dict: UserAPIKeyAuth,
):
from litellm.proxy.proxy_server import premium_user
if premium_user is True:
if (
user_api_key_dict.team_metadata is not None
and "tags" in user_api_key_dict.team_metadata
):
_team_tags = user_api_key_dict.team_metadata["tags"]
_tags_in_metadata = data[_metadata_variable_name].get("tags", [])
_tags_in_metadata.extend(_team_tags)
data[_metadata_variable_name]["tags"] = _tags_in_metadata
def move_guardrails_to_metadata(
data: dict,
_metadata_variable_name: str,

View file

@ -109,7 +109,9 @@ async def test_team_tag_routing():
headers = dict(headers)
print(response_a)
print(headers)
assert headers["x-litellm-model-id"] == "teama", "Model ID should be teamA"
assert (
headers["x-litellm-model-id"] == "team-a-model"
), "Model ID should be teamA"
key_with_team_b = await create_key_with_team(session, key, team_b_id)
_key_with_team_b = key_with_team_b["key"]
@ -118,7 +120,9 @@ async def test_team_tag_routing():
headers = dict(headers)
print(response_b)
print(headers)
assert headers["x-litellm-model-id"] == "teamb", "Model ID should be teamB"
assert (
headers["x-litellm-model-id"] == "team-b-model"
), "Model ID should be teamB"
@pytest.mark.asyncio()