forked from phoenix/litellm-mirror
fix(user_api_key_auth.py): add auth check for websocket endpoint
Fixes https://github.com/BerriAI/litellm/issues/6926
This commit is contained in:
parent
037171b98b
commit
9ec6ebaeeb
3 changed files with 58 additions and 3 deletions
|
@ -11,6 +11,10 @@ model_list:
|
|||
model: vertex_ai/claude-3-5-sonnet-v2
|
||||
vertex_ai_project: "adroit-crow-413218"
|
||||
vertex_ai_location: "us-east5"
|
||||
- model_name: openai-gpt-4o-realtime-audio
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-realtime-preview-2024-10-01
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2
|
||||
|
|
|
@ -28,6 +28,8 @@ from fastapi import (
|
|||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -195,6 +197,48 @@ def _is_allowed_route(
|
|||
)
|
||||
|
||||
|
||||
async def user_api_key_auth_websocket(websocket: WebSocket):
|
||||
# Accept the WebSocket connection
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = websocket.url
|
||||
|
||||
async def return_body():
|
||||
return_string = '{{"model": "fake-openai-endpoint"}}'
|
||||
# return string as bytes
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body
|
||||
|
||||
# Extract the Authorization header
|
||||
authorization = websocket.headers.get("authorization")
|
||||
|
||||
# If no Authorization header, try the api-key header
|
||||
if not authorization:
|
||||
api_key = websocket.headers.get("api-key")
|
||||
if not api_key:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(status_code=403, detail="No API key provided")
|
||||
else:
|
||||
# Extract the API key from the Bearer token
|
||||
if not authorization.startswith("Bearer "):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid Authorization header format"
|
||||
)
|
||||
|
||||
api_key = authorization[len("Bearer ") :].strip()
|
||||
|
||||
# Call user_api_key_auth with the extracted API key
|
||||
# Note: You'll need to modify this to work with WebSocket context if needed
|
||||
try:
|
||||
return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(e)
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
async def user_api_key_auth( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str = fastapi.Security(api_key_header),
|
||||
|
|
|
@ -134,7 +134,10 @@ from litellm.proxy.auth.model_checks import (
|
|||
get_key_models,
|
||||
get_team_models,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.auth.user_api_key_auth import (
|
||||
user_api_key_auth,
|
||||
user_api_key_auth_websocket,
|
||||
)
|
||||
|
||||
## Import All Misc routes here ##
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
|
@ -4394,7 +4397,11 @@ from litellm import _arealtime
|
|||
|
||||
|
||||
@app.websocket("/v1/realtime")
|
||||
async def websocket_endpoint(websocket: WebSocket, model: str):
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
model: str,
|
||||
user_api_key_dict=Depends(user_api_key_auth_websocket),
|
||||
):
|
||||
import websockets
|
||||
|
||||
await websocket.accept()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue