From 9ec6ebaeebd6399ad1b56a6761a4b34a7acfc54a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 27 Nov 2024 17:45:59 +0530 Subject: [PATCH] fix(user_api_key_auth.py): add auth check for websocket endpoint Fixes https://github.com/BerriAI/litellm/issues/6926 --- litellm/proxy/_new_secret_config.yaml | 6 +++- litellm/proxy/auth/user_api_key_auth.py | 44 +++++++++++++++++++++++++ litellm/proxy/proxy_server.py | 11 +++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 86ece3788..03d66351d 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -11,7 +11,11 @@ model_list: model: vertex_ai/claude-3-5-sonnet-v2 vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" - + - model_name: openai-gpt-4o-realtime-audio + litellm_params: + model: openai/gpt-4o-realtime-preview-2024-10-01 + api_key: os.environ/OPENAI_API_KEY + router_settings: routing_strategy: usage-based-routing-v2 #redis_url: "os.environ/REDIS_URL" diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 32f0c95db..99c7dc50e 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -28,6 +28,8 @@ from fastapi import ( Request, Response, UploadFile, + WebSocket, + WebSocketDisconnect, status, ) from fastapi.middleware.cors import CORSMiddleware @@ -195,6 +197,48 @@ def _is_allowed_route( ) +async def user_api_key_auth_websocket(websocket: WebSocket): + # Accept the WebSocket connection + + request = Request(scope={"type": "http"}) + request._url = websocket.url + + async def return_body(): + return_string = '{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body + + # Extract the Authorization header + authorization = websocket.headers.get("authorization") + + # If no Authorization header, try the api-key header + if not authorization: + api_key = websocket.headers.get("api-key") + if not api_key: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail="No API key provided") + else: + # Extract the API key from the Bearer token + if not authorization.startswith("Bearer "): + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException( + status_code=403, detail="Invalid Authorization header format" + ) + + api_key = authorization[len("Bearer ") :].strip() + + # Call user_api_key_auth with the extracted API key + # Note: You'll need to modify this to work with WebSocket context if needed + try: + return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}") + except Exception as e: + verbose_proxy_logger.exception(e) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail=str(e)) + + async def user_api_key_auth( # noqa: PLR0915 request: Request, api_key: str = fastapi.Security(api_key_header), diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 15971263a..8e04951ad 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -134,7 +134,10 @@ from litellm.proxy.auth.model_checks import ( get_key_models, get_team_models, ) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.auth.user_api_key_auth import ( + user_api_key_auth, + user_api_key_auth_websocket, +) ## Import All Misc routes here ## from litellm.proxy.caching_routes import router as caching_router @@ -4394,7 +4397,11 @@ from litellm import _arealtime @app.websocket("/v1/realtime") -async def websocket_endpoint(websocket: WebSocket, model: str): +async def websocket_endpoint( + websocket: WebSocket, + model: str, + user_api_key_dict=Depends(user_api_key_auth_websocket), +): import websockets await websocket.accept()