Merge pull request #1971 from BerriAI/litellm_fix_team_id

fix(proxy_server.py): fix proxy server team id bug
This commit is contained in:
Krish Dholakia 2024-02-13 23:24:38 -08:00 committed by GitHub
commit dde78d8f4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 41 additions and 31 deletions

View file

@ -682,35 +682,31 @@ async def user_api_key_auth(
# sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None)
if token_team is not None:
if token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/login",
"/key",
"/spend",
"/user",
"/model/info",
]
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(
route.startswith(allowed_route)
for allowed_route in allowed_routes
)
):
# Do something if the current route starts with any of the allowed routes
pass
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
if token_team is not None and token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/login",
"/key",
"/spend",
"/user",
"/model/info",
]
# check if the current route startswith any of the allowed routes
if (
route is not None
and isinstance(route, str)
and any(
route.startswith(allowed_route) for allowed_route in allowed_routes
)
):
# Do something if the current route starts with any of the allowed routes
pass
else:
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
except Exception as e:
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
traceback.print_exc()
@ -1599,8 +1595,6 @@ async def generate_key_helper_fn(
tpm_limit = tpm_limit
rpm_limit = rpm_limit
allowed_cache_controls = allowed_cache_controls
if type(team_id) is not str:
team_id = str(team_id)
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {

View file

@ -88,6 +88,22 @@ async def test_chat_completion():
await chat_completion(session=session, key=key_2)
@pytest.mark.asyncio
async def test_chat_completion_old_key():
"""
Production test for backwards compatibility. Test db against a pre-generated (old key)
- Create key
Make chat completion call
"""
async with aiohttp.ClientSession() as session:
try:
key = "sk-yNXvlRO4SxIGG0XnRMYxTw"
await chat_completion(session=session, key=key)
except Exception as e:
key = "sk-2KV0sAElLQqMpLZXdNf3yw" # try diff db key (in case db url is for the other db)
await chat_completion(session=session, key=key)
async def completion(session, key):
url = "http://0.0.0.0:4000/completions"
headers = {