mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-09 15:17:46 +00:00
gh auth, commit
# What does this PR do? ## Test Plan # What does this PR do? ## Test Plan
This commit is contained in:
parent
114946ae88
commit
94282d3482
9 changed files with 962 additions and 43 deletions
|
@ -10,13 +10,19 @@ from abc import ABC, abstractmethod
|
|||
from asyncio import Lock
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -62,6 +68,24 @@ class AuthProvider(ABC):
|
|||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Setup any provider-specific routes (e.g., OAuth callbacks).
|
||||
|
||||
This is optional - providers that don't need special routes can skip this.
|
||||
"""
|
||||
return
|
||||
|
||||
def get_public_paths(self) -> list[str]:
|
||||
"""Return a list of path prefixes that should bypass authentication.
|
||||
|
||||
This is optional - providers that don't have public paths return empty list.
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return provider-specific authentication error message."""
|
||||
return "Authentication required"
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
|
@ -232,6 +256,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
async def close(self):
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return OAuth2-specific authentication error message."""
|
||||
if self.config.issuer:
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||
elif self.config.introspection:
|
||||
# Extract domain from introspection URL for a cleaner message
|
||||
domain = urlparse(self.config.introspection.url).netloc
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||
else:
|
||||
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
"""
|
||||
Refresh the JWKS cache.
|
||||
|
@ -338,15 +373,25 @@ class CustomAuthProvider(AuthProvider):
|
|||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return custom auth provider-specific authentication error message."""
|
||||
# Extract domain from endpoint URL for a cleaner message
|
||||
domain = urlparse(self.config.endpoint).netloc
|
||||
if domain:
|
||||
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||
else:
|
||||
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "custom":
|
||||
if isinstance(config, CustomAuthConfig):
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
elif isinstance(config, OAuth2TokenAuthConfig):
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
elif isinstance(config, GitHubAuthConfig):
|
||||
from .github_oauth_auth_provider import GitHubAuthProvider
|
||||
|
||||
return GitHubAuthProvider(config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
raise ValueError(f"Unknown authentication config type: {type(config)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue