mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
gh auth, commit
# What does this PR do? ## Test Plan # What does this PR do? ## Test Plan
This commit is contained in:
parent
114946ae88
commit
94282d3482
9 changed files with 962 additions and 43 deletions
109
llama_stack/distribution/server/auth_routes.py
Normal file
109
llama_stack/distribution/server/auth_routes.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
from llama_stack.distribution.datatypes import GitHubAuthConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .github_oauth_auth_provider import (
|
||||
GITHUB_CALLBACK_PATH,
|
||||
GITHUB_LOGIN_PATH,
|
||||
GitHubAuthProvider,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="auth_routes")
|
||||
|
||||
|
||||
def create_github_auth_router(config: GitHubAuthConfig) -> APIRouter:
|
||||
"""Create and configure GitHub authentication router."""
|
||||
auth_provider = GitHubAuthProvider(config)
|
||||
router = APIRouter()
|
||||
oauth_states: dict[str, dict] = {}
|
||||
|
||||
def cleanup_expired_states() -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(UTC)
|
||||
expired_states = [
|
||||
state
|
||||
for state, data in oauth_states.items()
|
||||
if (now - data["created_at"]).seconds > 300 # 5 minutes
|
||||
]
|
||||
for state in expired_states:
|
||||
oauth_states.pop(state, None)
|
||||
|
||||
@router.get(GITHUB_LOGIN_PATH)
|
||||
async def github_login(request: Request, redirect_url: str | None = None):
|
||||
"""Initiate GitHub OAuth flow."""
|
||||
cleanup_expired_states()
|
||||
|
||||
# Validate redirect URL if provided
|
||||
if redirect_url:
|
||||
if not auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Redirect URLs not configured")
|
||||
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Invalid redirect URL")
|
||||
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
oauth_states[state] = {
|
||||
"created_at": datetime.now(UTC),
|
||||
"redirect_url": redirect_url,
|
||||
}
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = auth_provider.get_authorization_url(state, request)
|
||||
|
||||
logger.debug(f"Redirecting to GitHub OAuth: {auth_url}")
|
||||
return RedirectResponse(url=auth_url)
|
||||
|
||||
@router.get(GITHUB_CALLBACK_PATH)
|
||||
async def github_callback(code: str, state: str, request: Request):
|
||||
"""Handle GitHub OAuth callback."""
|
||||
# Validate state parameter
|
||||
if state not in oauth_states:
|
||||
logger.warning(f"Invalid OAuth state received: {state}")
|
||||
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||
|
||||
state_data = oauth_states.pop(state)
|
||||
|
||||
# Check state expiry
|
||||
if (datetime.now(UTC) - state_data["created_at"]).seconds > 300:
|
||||
logger.warning("OAuth state expired")
|
||||
raise HTTPException(status_code=400, detail="State expired")
|
||||
|
||||
try:
|
||||
# Complete OAuth flow
|
||||
access_token = await auth_provider.complete_oauth_flow(code, request)
|
||||
|
||||
logger.info("GitHub OAuth successful")
|
||||
|
||||
redirect_url = state_data.get("redirect_url")
|
||||
if redirect_url:
|
||||
# Validate redirect URL from state
|
||||
if redirect_url not in auth_provider.config.allowed_redirect_urls:
|
||||
raise HTTPException(status_code=400, detail="Invalid redirect URL in state")
|
||||
|
||||
import urllib.parse
|
||||
|
||||
params = urllib.parse.urlencode({"token": access_token})
|
||||
return RedirectResponse(url=f"{redirect_url}#{params}")
|
||||
else:
|
||||
# For API calls, only return the access token
|
||||
return JSONResponse(content={"access_token": access_token})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"GitHub OAuth error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error during GitHub OAuth")
|
||||
raise HTTPException(status_code=500, detail="Authentication failed") from e
|
||||
|
||||
return router
|
Loading…
Add table
Add a link
Reference in a new issue