forked from phoenix/litellm-mirror
fix team based tag routing
commit da2cefc45a (parent 308377fbe2)
3 changed files with 22 additions and 42 deletions
Config file (model_list):

@@ -1,20 +1,20 @@
 model_list:
-  - model_name: fake-openai-endpoint
-    litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-      tags: ["teamA"]
-    model_info:
-      id: "teama"
-  - model_name: fake-openai-endpoint
-    litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-      tags: ["teamB"]
-    model_info:
-      id: "teamb"
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      tags: ["teamA"] # 👈 Key Change
+    model_info:
+      id: "team-a-model" # used for identifying model in response headers
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      tags: ["teamB"] # 👈 Key Change
+    model_info:
+      id: "team-b-model" # used for identifying model in response headers
   - model_name: rerank-english-v3.0
     litellm_params:
       model: cohere/rerank-english-v3.0
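Not part of the diff, but for context: both fake-openai-endpoint deployments share one model_name and are distinguished only by their tags, while model_info.id is what the proxy echoes back in the x-litellm-model-id response header. A minimal client-side sketch of how that can be checked is below; the proxy address and the team-A virtual key are placeholders, not values from this commit.

import httpx  # any HTTP client works; httpx is used here for brevity

PROXY_BASE = "http://localhost:4000"        # assumed local proxy address
TEAM_A_KEY = "sk-team-a-placeholder"        # hypothetical key belonging to a team tagged "teamA"

resp = httpx.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {TEAM_A_KEY}"},
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "ping"}],
    },
)

# model_info.id of the deployment that served the request is surfaced in this header
print(resp.headers.get("x-litellm-model-id"))  # expected: "team-a-model"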
Proxy request setup (add_litellm_data_to_request):

@@ -379,12 +379,6 @@ async def add_litellm_data_to_request(
             # unpack callback_vars in data
             for k, v in callback_settings_obj.callback_vars.items():
                 data[k] = v
-    # Team based tags
-    add_team_based_tags_to_metadata(
-        data=data,
-        _metadata_variable_name=_metadata_variable_name,
-        user_api_key_dict=user_api_key_dict,
-    )
 
     # Guardrails
     move_guardrails_to_metadata(
@@ -396,24 +390,6 @@ async def add_litellm_data_to_request(
     return data
 
 
-def add_team_based_tags_to_metadata(
-    data: dict,
-    _metadata_variable_name: str,
-    user_api_key_dict: UserAPIKeyAuth,
-):
-    from litellm.proxy.proxy_server import premium_user
-
-    if premium_user is True:
-        if (
-            user_api_key_dict.team_metadata is not None
-            and "tags" in user_api_key_dict.team_metadata
-        ):
-            _team_tags = user_api_key_dict.team_metadata["tags"]
-            _tags_in_metadata = data[_metadata_variable_name].get("tags", [])
-            _tags_in_metadata.extend(_team_tags)
-            data[_metadata_variable_name]["tags"] = _tags_in_metadata
-
-
 def move_guardrails_to_metadata(
     data: dict,
     _metadata_variable_name: str,
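For readers skimming the hunks above: the deleted helper, for premium users, appended the team's tags from user_api_key_dict.team_metadata onto the tags already present in the request's metadata. A standalone sketch of that merge behavior, using plain dicts in place of UserAPIKeyAuth and illustrative names throughout, looks like this:

def merge_team_tags(data: dict, metadata_key: str, team_metadata: dict | None) -> None:
    """Append team-level tags to the request's metadata tags in place."""
    if team_metadata is None or "tags" not in team_metadata:
        return
    existing = data.setdefault(metadata_key, {}).get("tags", [])
    data[metadata_key]["tags"] = existing + team_metadata["tags"]


request_data = {"metadata": {"tags": ["request-level"]}}
merge_team_tags(request_data, "metadata", {"tags": ["teamA"]})
print(request_data["metadata"]["tags"])  # ['request-level', 'teamA']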
Test (test_team_tag_routing):

@@ -109,7 +109,9 @@ async def test_team_tag_routing():
         headers = dict(headers)
         print(response_a)
         print(headers)
-        assert headers["x-litellm-model-id"] == "teama", "Model ID should be teamA"
+        assert (
+            headers["x-litellm-model-id"] == "team-a-model"
+        ), "Model ID should be teamA"
 
         key_with_team_b = await create_key_with_team(session, key, team_b_id)
         _key_with_team_b = key_with_team_b["key"]
@@ -118,7 +120,9 @@ async def test_team_tag_routing():
         headers = dict(headers)
         print(response_b)
         print(headers)
-        assert headers["x-litellm-model-id"] == "teamb", "Model ID should be teamB"
+        assert (
+            headers["x-litellm-model-id"] == "team-b-model"
+        ), "Model ID should be teamB"
 
 
 @pytest.mark.asyncio()
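The create_key_with_team helper used in the test is defined elsewhere in the test suite; a hypothetical sketch of its shape, assuming the proxy's /key/generate endpoint on a local placeholder address and an admin key passed in as the second argument, could look like this:

import aiohttp


async def create_key_with_team(
    session: aiohttp.ClientSession, admin_key: str, team_id: str
) -> dict:
    # Ask the proxy to mint a virtual key scoped to team_id (placeholder URL below).
    async with session.post(
        "http://0.0.0.0:4000/key/generate",
        headers={"Authorization": f"Bearer {admin_key}"},
        json={"team_id": team_id},
    ) as resp:
        return await resp.json()

With a key created this way for team B, requests should be routed to the deployment tagged "teamB" and come back with x-litellm-model-id: team-b-model, which is exactly what the second assertion checks.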