fix(proxy_server.py): fix model alias map + add back testing

This commit is contained in:
Krrish Dholakia 2024-03-07 07:56:51 -08:00
parent b9854a99d2
commit dd78a1956a
2 changed files with 71 additions and 6 deletions

View file

@ -2541,6 +2541,12 @@ async def completion(
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# If the requested model name is an alias in litellm.model_alias_map,
# replace it with the actual model name it maps to.
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
@ -2740,6 +2746,12 @@ async def chat_completion(
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# If the requested model name is an alias in litellm.model_alias_map,
# replace it with the actual model name it maps to.
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
@ -2948,6 +2960,12 @@ async def embeddings(
**data,
} # add the team-specific configs to the completion call
### MODEL ALIAS MAPPING ###
# If the requested model name is an alias in litellm.model_alias_map,
# replace it with the actual model name it maps to.
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
@ -3119,6 +3137,12 @@ async def image_generation(
**data,
} # add the team-specific configs to the completion call
### MODEL ALIAS MAPPING ###
# If the requested model name is an alias in litellm.model_alias_map,
# replace it with the actual model name it maps to.
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None

View file

@ -7,11 +7,13 @@ import time, uuid
from openai import AsyncOpenAI
async def new_user(session, i, user_id=None, budget=None, budget_duration=None):
async def new_user(
session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"]
):
url = "http://0.0.0.0:4000/user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models"],
"models": models,
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
@ -125,17 +127,22 @@ async def chat_completion(session, key, model="gpt-4"):
pass
async def new_team(session, i, user_id=None, member_list=None):
async def new_team(session, i, user_id=None, member_list=None, model_aliases=None):
import json
url = "http://0.0.0.0:4000/team/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"team_alias": "my-new-team",
}
data = {"team_alias": "my-new-team"}
if user_id is not None:
data["members_with_roles"] = [{"role": "user", "user_id": user_id}]
elif member_list is not None:
data["members_with_roles"] = member_list
if model_aliases is not None:
data["model_aliases"] = model_aliases
print(f"data: {data}")
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
@ -351,3 +358,37 @@ async def test_member_delete():
member_id_list.append(member["user_id"])
assert normal_user not in member_id_list
@pytest.mark.asyncio
async def test_team_alias():
    """
    Exercise a team-level model alias end to end.

    - Register two users (a future team admin and a regular member)
    - Create a team whose model aliases map "cheap-model" -> "gpt-3.5-turbo"
    - Generate a key for that team, limited to "gpt-3.5-turbo"
    - Send a chat completion with the key, using the alias as the model name
    """
    async with aiohttp.ClientSession() as session:
        # Random UUIDs serve as the user ids for this run.
        admin_id = str(uuid.uuid4())
        await new_user(session=session, i=0, user_id=admin_id)

        member_id = str(uuid.uuid4())
        await new_user(session=session, i=0, user_id=member_id)

        # Team roster: one admin plus one regular user.
        roster = [
            {"role": "admin", "user_id": admin_id},
            {"role": "user", "user_id": member_id},
        ]
        team = await new_team(
            session=session,
            i=0,
            member_list=roster,
            model_aliases={"cheap-model": "gpt-3.5-turbo"},
        )

        # The key only lists the real model name, not the alias.
        generated = await generate_key(
            session=session,
            i=0,
            team_id=team["team_id"],
            models=["gpt-3.5-turbo"],
        )
        team_key = generated["key"]

        # Request the alias; the returned value is not inspected here.
        # NOTE(review): presumably chat_completion raises on a failed
        # request, which is what would fail this test - confirm.
        await chat_completion(session=session, key=team_key, model="cheap-model")