gh auth, commit

# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-06-24 16:28:17 -07:00
parent 114946ae88
commit 94282d3482
9 changed files with 962 additions and 43 deletions

View file

@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -62,6 +68,24 @@ class AuthProvider(ABC):
"""Clean up any resources."""
pass
def setup_routes(self, app):
"""Setup any provider-specific routes (e.g., OAuth callbacks).
This is optional - providers that don't need special routes can skip this.
"""
return
def get_public_paths(self) -> list[str]:
"""Return a list of path prefixes that should bypass authentication.
This is optional - providers that don't have public paths return empty list.
"""
return []
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
# Extract domain from endpoint URL for a cleaner message
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "custom":
if isinstance(config, CustomAuthConfig):
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
elif isinstance(config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
elif isinstance(config, GitHubAuthConfig):
from .github_oauth_auth_provider import GitHubAuthProvider
return GitHubAuthProvider(config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
raise ValueError(f"Unknown authentication config type: {type(config)}")