mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
109 lines
4 KiB
Python
109 lines
4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import secrets
|
|
from datetime import UTC, datetime
|
|
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from fastapi.responses import JSONResponse, RedirectResponse
|
|
|
|
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
|
from llama_stack.log import get_logger
|
|
|
|
from .github_oauth_auth_provider import (
|
|
GITHUB_CALLBACK_PATH,
|
|
GITHUB_LOGIN_PATH,
|
|
GitHubAuthProvider,
|
|
)
|
|
|
|
# Module-level logger for this file, requested with category "auth_routes".
logger = get_logger(name=__name__, category="auth_routes")
|
|
|
|
|
|
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
    """Create and configure the GitHub authentication router.

    Builds a ``GitHubAuthProvider`` from ``config`` and registers two GET
    routes on a fresh ``APIRouter``:

    * ``GITHUB_LOGIN_PATH`` -- generates a random ``state`` token, remembers it
      and redirects the caller to the provider's authorization URL.
    * ``GITHUB_CALLBACK_PATH`` -- validates the returned ``state``, completes
      the OAuth flow and hands the resulting access token back to the caller.

    Pending states are kept in a dict that lives inside this factory call, so
    each returned router has its own, in-process state table.

    Args:
        config: GitHub auth configuration passed to ``GitHubAuthProvider``.

    Returns:
        The router with the login and callback routes attached.
    """
    import urllib.parse

    auth_provider = GitHubAuthProvider(config)
    router = APIRouter()

    # state token -> {"created_at": aware datetime, "redirect_url": str | None}
    oauth_states: dict[str, dict] = {}

    # Lifetime of a pending OAuth state, shared by cleanup and the callback.
    state_ttl_seconds = 300  # 5 minutes

    def is_state_expired(created_at: datetime, now: datetime) -> bool:
        """Return True when a state created at ``created_at`` is past its TTL."""
        # total_seconds() is required here: timedelta.seconds is only the
        # 0..86399 seconds component and ignores whole days, so a state that is
        # e.g. 1 day + 10 s old would otherwise look 10 s old.
        return (now - created_at).total_seconds() > state_ttl_seconds

    def cleanup_expired_states() -> None:
        """Remove expired OAuth states."""
        now = datetime.now(UTC)
        # Collect first, then delete: the dict must not change size while it is
        # being iterated.
        expired_states = [
            state for state, data in oauth_states.items() if is_state_expired(data["created_at"], now)
        ]
        for state in expired_states:
            oauth_states.pop(state, None)

    @router.get(GITHUB_LOGIN_PATH)
    async def github_login(request: Request, redirect_url: str | None = None):
        """Initiate GitHub OAuth flow.

        Raises:
            HTTPException: 400 when ``redirect_url`` is given but no redirect
                URLs are configured, or when it is not one of the allowed URLs.
        """
        cleanup_expired_states()

        # Validate redirect URL if provided
        if redirect_url:
            allowed_redirect_urls = auth_provider.config.allowed_redirect_urls
            if not allowed_redirect_urls:
                raise HTTPException(status_code=400, detail="Redirect URLs not configured")
            if redirect_url not in allowed_redirect_urls:
                raise HTTPException(status_code=400, detail="Invalid redirect URL")

        state = secrets.token_urlsafe(32)
        oauth_states[state] = {
            "created_at": datetime.now(UTC),
            "redirect_url": redirect_url,
        }

        # Get authorization URL
        auth_url = auth_provider.get_authorization_url(state, request)

        logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
        return RedirectResponse(url=auth_url)

    @router.get(GITHUB_CALLBACK_PATH)
    async def github_callback(code: str, state: str, request: Request):
        """Handle GitHub OAuth callback.

        Returns a redirect to the stored redirect URL with the token in the URL
        fragment, or a JSON body ``{"access_token": ...}`` when no redirect URL
        was stored for this state.

        Raises:
            HTTPException: 400 for an unknown or expired state, an invalid
                stored redirect URL, or a ``ValueError`` from the provider;
                500 for any other unexpected failure.
        """
        # Validate state parameter. Popping makes every state single-use.
        state_data = oauth_states.pop(state, None)
        if state_data is None:
            logger.warning(f"Invalid OAuth state received: {state}")
            raise HTTPException(status_code=400, detail="Invalid state parameter")

        # Check state expiry
        if is_state_expired(state_data["created_at"], datetime.now(UTC)):
            logger.warning("OAuth state expired")
            raise HTTPException(status_code=400, detail="State expired")

        try:
            # Complete OAuth flow
            access_token = await auth_provider.complete_oauth_flow(code, request)

            logger.info("GitHub OAuth successful")

            redirect_url = state_data.get("redirect_url")
            if redirect_url:
                # Validate redirect URL from state. ``or ()`` keeps an unset
                # allow-list from raising TypeError on the membership test.
                allowed_redirect_urls = auth_provider.config.allowed_redirect_urls or ()
                if redirect_url not in allowed_redirect_urls:
                    raise HTTPException(status_code=400, detail="Invalid redirect URL in state")

                params = urllib.parse.urlencode({"token": access_token})
                return RedirectResponse(url=f"{redirect_url}#{params}")

            # For API calls, only return the access token
            return JSONResponse(content={"access_token": access_token})

        except HTTPException:
            # Deliberate HTTP errors raised above must reach the client as-is;
            # without this clause the generic handler below would turn the
            # 400 into a 500.
            raise
        except ValueError as e:
            logger.error(f"GitHub OAuth error: {e}")
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            logger.exception("Unexpected error during GitHub OAuth")
            raise HTTPException(status_code=500, detail="Authentication failed") from e

    return router
|