diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 45d6a2296..3dedc3a71 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -682,35 +682,31 @@ async def user_api_key_auth( # sso/login, ui/login, /key functions and /user functions # this will never be allowed to call /chat/completions token_team = getattr(valid_token, "team_id", None) - if token_team is not None: - if token_team == "litellm-dashboard": - # this token is only used for managing the ui - allowed_routes = [ - "/sso", - "/login", - "/key", - "/spend", - "/user", - "/model/info", - ] - # check if the current route startswith any of the allowed routes - if ( - route is not None - and isinstance(route, str) - and any( - route.startswith(allowed_route) - for allowed_route in allowed_routes - ) - ): - # Do something if the current route starts with any of the allowed routes - pass - else: - raise Exception( - f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed" - ) - return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) - else: - raise Exception(f"Invalid Key Passed to LiteLLM Proxy") + if token_team is not None and token_team == "litellm-dashboard": + # this token is only used for managing the ui + allowed_routes = [ + "/sso", + "/login", + "/key", + "/spend", + "/user", + "/model/info", + ] + # check if the current route startswith any of the allowed routes + if ( + route is not None + and isinstance(route, str) + and any( + route.startswith(allowed_route) for allowed_route in allowed_routes + ) + ): + # Do something if the current route starts with any of the allowed routes + pass + else: + raise Exception( + f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed" + ) + return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) except Exception as e: # verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}") traceback.print_exc() @@ -1599,8 +1595,6 @@ async def generate_key_helper_fn( tpm_limit = tpm_limit rpm_limit = rpm_limit allowed_cache_controls = allowed_cache_controls - if type(team_id) is not str: - team_id = str(team_id) try: # Create a new verification token (you may want to enhance this logic based on your needs) user_data = { diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py index 67d7c4db9..fefb4243f 100644 --- a/tests/test_openai_endpoints.py +++ b/tests/test_openai_endpoints.py @@ -88,6 +88,22 @@ async def test_chat_completion(): await chat_completion(session=session, key=key_2) +@pytest.mark.asyncio +async def test_chat_completion_old_key(): + """ + Production test for backwards compatibility. Test db against a pre-generated (old key) + - Create key + Make chat completion call + """ + async with aiohttp.ClientSession() as session: + try: + key = "sk-yNXvlRO4SxIGG0XnRMYxTw" + await chat_completion(session=session, key=key) + except Exception as e: + key = "sk-2KV0sAElLQqMpLZXdNf3yw" # try diff db key (in case db url is for the other db) + await chat_completion(session=session, key=key) + + async def completion(session, key): url = "http://0.0.0.0:4000/completions" headers = {