fix(proxy_server.py): fix proxy server team id bug

This commit is contained in:
Krrish Dholakia 2024-02-13 22:33:56 -08:00
parent dc0b2b4501
commit 83d43809a7
2 changed files with 49 additions and 37 deletions

View file

@ -166,9 +166,9 @@ class ProxyException(Exception):
async def openai_exception_handler(request: Request, exc: ProxyException): async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
return JSONResponse( return JSONResponse(
status_code=int(exc.code) status_code=(
if exc.code int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
else status.HTTP_500_INTERNAL_SERVER_ERROR, ),
content={ content={
"error": { "error": {
"message": exc.message, "message": exc.message,
@ -682,35 +682,31 @@ async def user_api_key_auth(
# sso/login, ui/login, /key functions and /user functions # sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions # this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None) token_team = getattr(valid_token, "team_id", None)
if token_team is not None: if token_team is not None and token_team == "litellm-dashboard":
if token_team == "litellm-dashboard": # this token is only used for managing the ui
# this token is only used for managing the ui allowed_routes = [
allowed_routes = [ "/sso",
"/sso", "/login",
"/login", "/key",
"/key", "/spend",
"/spend", "/user",
"/user", "/model/info",
"/model/info", ]
] # check if the current route startswith any of the allowed routes
# check if the current route startswith any of the allowed routes if (
if ( route is not None
route is not None and isinstance(route, str)
and isinstance(route, str) and any(
and any( route.startswith(allowed_route) for allowed_route in allowed_routes
route.startswith(allowed_route) )
for allowed_route in allowed_routes ):
) # Do something if the current route starts with any of the allowed routes
): pass
# Do something if the current route starts with any of the allowed routes else:
pass raise Exception(
else: f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
raise Exception( )
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed" return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
except Exception as e: except Exception as e:
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}") # verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
traceback.print_exc() traceback.print_exc()
@ -1599,8 +1595,6 @@ async def generate_key_helper_fn(
tpm_limit = tpm_limit tpm_limit = tpm_limit
rpm_limit = rpm_limit rpm_limit = rpm_limit
allowed_cache_controls = allowed_cache_controls allowed_cache_controls = allowed_cache_controls
if type(team_id) is not str:
team_id = str(team_id)
try: try:
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
user_data = { user_data = {
@ -4553,9 +4547,11 @@ async def get_routes():
"path": getattr(route, "path", None), "path": getattr(route, "path", None),
"methods": getattr(route, "methods", None), "methods": getattr(route, "methods", None),
"name": getattr(route, "name", None), "name": getattr(route, "name", None),
"endpoint": getattr(route, "endpoint", None).__name__ "endpoint": (
if getattr(route, "endpoint", None) getattr(route, "endpoint", None).__name__
else None, if getattr(route, "endpoint", None)
else None
),
} }
routes.append(route_info) routes.append(route_info)

View file

@ -88,6 +88,22 @@ async def test_chat_completion():
await chat_completion(session=session, key=key_2) await chat_completion(session=session, key=key_2)
@pytest.mark.asyncio
async def test_chat_completion_old_key():
"""
Production test for backwards compatibility. Test db against a pre-generated (old key)
- Create key
Make chat completion call
"""
async with aiohttp.ClientSession() as session:
try:
key = "sk-yNXvlRO4SxIGG0XnRMYxTw"
await chat_completion(session=session, key=key)
except Exception as e:
key = "sk-2KV0sAElLQqMpLZXdNf3yw" # try diff db key (in case db url is for the other db)
await chat_completion(session=session, key=key)
async def completion(session, key): async def completion(session, key):
url = "http://0.0.0.0:4000/completions" url = "http://0.0.0.0:4000/completions"
headers = { headers = {