From dd78a1956a591aadd32193289f5a53e21f1bc0f2 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 7 Mar 2024 07:56:51 -0800 Subject: [PATCH] fix(proxy_server.py): fix model alias map + add back testing --- litellm/proxy/proxy_server.py | 24 ++++++++++++++++ tests/test_team.py | 53 +++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 95570d7fa..19ac5c961 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2541,6 +2541,12 @@ async def completion( if user_api_base: data["api_base"] = user_api_base + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + ### CALL HOOKS ### - modify incoming data before calling the model data = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, data=data, call_type="completion" @@ -2740,6 +2746,12 @@ async def chat_completion( if user_api_base: data["api_base"] = user_api_base + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + ### CALL HOOKS ### - modify incoming data before calling the model data = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, data=data, call_type="completion" @@ -2948,6 +2960,12 @@ async def embeddings( **data, } # add the team-specific configs to the completion call + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + router_model_names = ( [m["model_name"] for m in llm_model_list] if llm_model_list is not None @@ -3119,6 +3137,12 @@ async def image_generation( **data, } # add the team-specific configs to the completion call + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + router_model_names = ( [m["model_name"] for m in llm_model_list] if llm_model_list is not None diff --git a/tests/test_team.py b/tests/test_team.py index 15303331a..f0ef0bb22 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -7,11 +7,13 @@ import time, uuid from openai import AsyncOpenAI -async def new_user(session, i, user_id=None, budget=None, budget_duration=None): +async def new_user( + session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"] +): url = "http://0.0.0.0:4000/user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models"], + "models": models, "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, "max_budget": budget, @@ -125,17 +127,22 @@ async def chat_completion(session, key, model="gpt-4"): pass -async def new_team(session, i, user_id=None, member_list=None): +async def new_team(session, i, user_id=None, member_list=None, model_aliases=None): + import json + url = "http://0.0.0.0:4000/team/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = { - "team_alias": "my-new-team", - } + data = {"team_alias": "my-new-team"} if user_id is not None: data["members_with_roles"] = [{"role": "user", "user_id": user_id}] elif member_list is not None: data["members_with_roles"] = member_list + if model_aliases is not None: + data["model_aliases"] = model_aliases + + print(f"data: {data}") + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -351,3 +358,37 @@ async def test_member_delete(): member_id_list.append(member["user_id"]) assert normal_user not in member_id_list + + +@pytest.mark.asyncio +async def test_team_alias(): + """ + - Create team w/ model alias + - Create key for team + - Check if key works + """ + async with aiohttp.ClientSession() as session: + ## Create admin + admin_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=admin_user) + ## Create normal user + normal_user = f"{uuid.uuid4()}" + await new_user(session=session, i=0, user_id=normal_user) + ## Create team with 1 admin and 1 user + member_list = [ + {"role": "admin", "user_id": admin_user}, + {"role": "user", "user_id": normal_user}, + ] + team_data = await new_team( + session=session, + i=0, + member_list=member_list, + model_aliases={"cheap-model": "gpt-3.5-turbo"}, + ) + ## Create key + key_gen = await generate_key( + session=session, i=0, team_id=team_data["team_id"], models=["gpt-3.5-turbo"] + ) + key = key_gen["key"] + ## Test key + response = await chat_completion(session=session, key=key, model="cheap-model")