llama-stack-mirror/llama_stack/distribution/server/auth_routes.py
Eric Huang 94282d3482 gh auth, commit
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
2025-06-26 09:10:46 -07:00

109 lines
4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import secrets
import urllib.parse
from datetime import UTC, datetime, timedelta

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse

from llama_stack.distribution.datatypes import GitHubAuthConfig
from llama_stack.log import get_logger

from .github_oauth_auth_provider import (
    GITHUB_CALLBACK_PATH,
    GITHUB_LOGIN_PATH,
    GitHubAuthProvider,
)
logger = get_logger(name=__name__, category="auth_routes")

# How long an OAuth ``state`` token issued by the login endpoint stays valid.
_OAUTH_STATE_TTL = timedelta(minutes=5)


def _is_state_expired(created_at: datetime, now: datetime | None = None) -> bool:
    """Return True if a state created at *created_at* is older than ``_OAUTH_STATE_TTL``.

    Compares whole ``timedelta`` values instead of ``timedelta.seconds``: the
    ``seconds`` attribute holds only the within-day component (0-86399), so a
    state older than a day could otherwise look fresh.

    Args:
        created_at: Timezone-aware creation time of the state.
        now: Reference time; defaults to the current UTC time.
    """
    if now is None:
        now = datetime.now(UTC)
    return now - created_at > _OAUTH_STATE_TTL


def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
    """Create and configure the GitHub OAuth authentication router.

    Registers two GET routes:

    * ``GITHUB_LOGIN_PATH``: issues a random ``state`` token, remembers it
      together with an optional allow-listed ``redirect_url``, and redirects
      the client to the URL from ``GitHubAuthProvider.get_authorization_url``.
    * ``GITHUB_CALLBACK_PATH``: validates and consumes the ``state`` token,
      completes the flow via ``GitHubAuthProvider.complete_oauth_flow``, and
      then either redirects to the stored ``redirect_url`` with the token in
      the URL fragment or returns ``{"access_token": ...}`` as JSON.

    Pending states are kept in a dict local to this router instance. They are
    therefore not shared with other router instances or processes.

    Args:
        config: GitHub auth configuration, passed to ``GitHubAuthProvider``.

    Returns:
        An ``APIRouter`` with the login and callback routes registered.
    """
    auth_provider = GitHubAuthProvider(config)
    router = APIRouter()
    # state token -> {"created_at": aware UTC datetime, "redirect_url": str | None}
    # NOTE(review): entries are evicted only by TTL, so the dict grows with the
    # login request rate over one TTL window; consider capping its size.
    oauth_states: dict[str, dict] = {}

    def cleanup_expired_states() -> None:
        """Remove every pending state older than the state TTL."""
        now = datetime.now(UTC)
        stale = [token for token, data in oauth_states.items() if _is_state_expired(data["created_at"], now)]
        for token in stale:
            oauth_states.pop(token, None)

    @router.get(GITHUB_LOGIN_PATH)
    async def github_login(request: Request, redirect_url: str | None = None):
        """Initiate the GitHub OAuth flow.

        Args:
            request: Incoming request, forwarded to ``get_authorization_url``.
            redirect_url: Optional URL to send the client to after the
                callback. It must appear in ``config.allowed_redirect_urls``.

        Raises:
            HTTPException: 400 if ``redirect_url`` is given but no redirect
                URLs are configured, or it is not in the allow-list.
        """
        cleanup_expired_states()

        if redirect_url:
            allowed = auth_provider.config.allowed_redirect_urls
            if not allowed:
                raise HTTPException(status_code=400, detail="Redirect URLs not configured")
            if redirect_url not in allowed:
                raise HTTPException(status_code=400, detail="Invalid redirect URL")

        state = secrets.token_urlsafe(32)
        oauth_states[state] = {
            "created_at": datetime.now(UTC),
            "redirect_url": redirect_url,
        }

        auth_url = auth_provider.get_authorization_url(state, request)
        logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
        return RedirectResponse(url=auth_url)

    @router.get(GITHUB_CALLBACK_PATH)
    async def github_callback(code: str, state: str, request: Request):
        """Handle the GitHub OAuth callback.

        The ``state`` token is single-use. It is removed before any further
        checks, so a replayed callback with the same state is rejected.

        Raises:
            HTTPException: 400 for an unknown or expired state, for a
                ``ValueError`` from ``complete_oauth_flow``, or for a stored
                redirect URL that is not in the allow-list. 500 for any other
                error while completing the OAuth flow.
        """
        state_data = oauth_states.pop(state, None)
        if state_data is None:
            logger.warning(f"Invalid OAuth state received: {state}")
            raise HTTPException(status_code=400, detail="Invalid state parameter")

        if _is_state_expired(state_data["created_at"]):
            logger.warning("OAuth state expired")
            raise HTTPException(status_code=400, detail="State expired")

        # Keep the try narrow so that HTTPExceptions raised further down (e.g.
        # the 400 for a bad redirect URL) are not re-wrapped as a 500 by the
        # broad ``except Exception``.
        try:
            access_token = await auth_provider.complete_oauth_flow(code, request)
        except ValueError as e:
            logger.error(f"GitHub OAuth error: {e}")
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            logger.exception("Unexpected error during GitHub OAuth")
            raise HTTPException(status_code=500, detail="Authentication failed") from e

        logger.info("GitHub OAuth successful")

        redirect_url = state_data.get("redirect_url")
        if not redirect_url:
            # For API calls, only return the access token.
            return JSONResponse(content={"access_token": access_token})

        # Re-validate the stored URL against the allow-list before redirecting.
        if redirect_url not in (auth_provider.config.allowed_redirect_urls or []):
            raise HTTPException(status_code=400, detail="Invalid redirect URL in state")

        # The token goes in the URL fragment (``#token=...``) rather than the
        # query string, presumably to keep it out of requests to that server.
        params = urllib.parse.urlencode({"token": access_token})
        return RedirectResponse(url=f"{redirect_url}#{params}")

    return router