mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-24 16:57:21 +00:00 
			
		
		
		
	# What does this PR do? This PR adds GitHub OAuth authentication support to Llama Stack, allowing users to authenticate using their GitHub credentials (#2508). 1. support verifying GitHub access tokens 2. support provider-specific auth error messages 3. opportunistically reorganized the auth configs for better ergonomics ## Test Plan Added unit tests. Also tested e2e manually: ``` server: port: 8321 auth: provider_config: type: github_token ``` ``` ~/projects/llama-stack/llama_stack/ui ❯ curl -v http://localhost:8321/v1/models * Host localhost:8321 was resolved. * IPv6: ::1 * IPv4: 127.0.0.1 * Trying [::1]:8321... * Connected to localhost (::1) port 8321 > GET /v1/models HTTP/1.1 > Host: localhost:8321 > User-Agent: curl/8.7.1 > Accept: */* > * Request completely sent off < HTTP/1.1 401 Unauthorized < date: Fri, 27 Jun 2025 21:51:25 GMT < server: uvicorn < content-type: application/json < x-trace-id: 5390c6c0654086c55d87c86d7cbf2f6a < Transfer-Encoding: chunked < * Connection #0 to host localhost left intact {"error": {"message": "Authentication required. 
Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"}} ~/projects/llama-stack/llama_stack/ui ❯ ./scripts/unit-tests.sh ~/projects/llama-stack/llama_stack/ui ❯ curl "http://localhost:8321/v1/models" \ -H "Authorization: Bearer <token_obtained_from_github>" \ {"data":[{"identifier":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_resource_id":"accounts/fireworks/models/llama-guard-3-11b-vision","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-guard-3-8b","provider_resource_id":"accounts/fireworks/models/llama-guard-3-8b","provider_id":"fireworks","type":"model","metadata":{},"model_type":"llm"},{"identifier":"accounts/fireworks/models/llama-v3p1-405b-instruct","provider_resource_id":"accounts/f ``` --------- Co-authored-by: Claude <noreply@anthropic.com>
		
			
				
	
	
		
			388 lines
		
	
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			388 lines
		
	
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import ssl
 | |
| import time
 | |
| from abc import ABC, abstractmethod
 | |
| from asyncio import Lock
 | |
| from urllib.parse import parse_qs, urlparse
 | |
| 
 | |
| import httpx
 | |
| from jose import jwt
 | |
| from pydantic import BaseModel, Field
 | |
| 
 | |
| from llama_stack.distribution.datatypes import (
 | |
|     AuthenticationConfig,
 | |
|     CustomAuthConfig,
 | |
|     GitHubTokenAuthConfig,
 | |
|     OAuth2TokenAuthConfig,
 | |
|     User,
 | |
| )
 | |
| from llama_stack.log import get_logger
 | |
| 
 | |
| logger = get_logger(name=__name__, category="auth")
 | |
| 
 | |
| 
 | |
class AuthResponse(BaseModel):
    """The format of the authentication response from the auth endpoint."""

    # Principal reported by the auth endpoint for the validated token.
    principal: str

    # Further attributes that may be used for access control decisions;
    # optional, defaults to None when the endpoint does not send any.
    attributes: dict[str, list[str]] | None = None

    # Optional human-readable context about the authentication result.
    message: str | None = Field(
        description="Optional message providing additional context about the authentication result.",
        default=None,
    )
 | |
| 
 | |
| 
 | |
class AuthRequestContext(BaseModel):
    # Describes the request being authenticated: its path, its headers and
    # its query parameters. Nested inside AuthRequest.

    path: str = Field(
        description="The path of the request being authenticated",
    )

    headers: dict[str, str] = Field(
        description="HTTP headers from the original request (excluding Authorization)",
    )

    # Each query parameter maps to the list of values supplied for it;
    # defaults to an empty dict.
    params: dict[str, list[str]] = Field(
        description="Query parameters from the original request",
        default_factory=dict,
    )
 | |
| 
 | |
| 
 | |
class AuthRequest(BaseModel):
    # Pairs the caller's credential with a description of the request that
    # the credential is being checked for.

    api_key: str = Field(
        description="The API key extracted from the Authorization header",
    )

    request: AuthRequestContext = Field(
        description="Context information about the request being authenticated",
    )
 | |
| 
 | |
| 
 | |
class AuthProvider(ABC):
    """Interface that every authentication backend has to implement."""

    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Check ``token`` and return the ``User`` it identifies.

        ``scope`` is optional request context that an implementation may consult.
        """

    @abstractmethod
    async def close(self):
        """Release any resources the provider is holding."""

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return the message for unauthenticated callers; subclasses may override it."""
        message = "Authentication required"
        return message
 | |
| 
 | |
| 
 | |
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
    """Translate claims into access attributes according to ``mapping``.

    Args:
        claims: Claim name -> claim value. Despite the annotation, a value may
            be a list of entries, a whitespace-separated string, or another
            scalar (which is stringified).
        mapping: Claim name -> attribute name. Several claims may target the
            same attribute; their values are concatenated in mapping order.

    Returns:
        Attribute name -> list of values. Claims that are missing from
        ``claims`` or whose value is ``None`` contribute nothing.

    The input ``claims`` mapping and the lists it contains are never mutated.
    """
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims:
            continue
        claim = claims[claim_key]
        if claim is None:
            # A null claim carries no values.
            continue

        if isinstance(claim, str):
            values = claim.split()
        elif isinstance(claim, (list, tuple)):
            # Copy so that a later extend() for the same attribute cannot
            # modify the caller's claim list in place.
            values = list(claim)
        else:
            # Numbers, booleans, etc.: keep them as a single string value
            # instead of failing on the missing .split() method.
            values = [str(claim)]

        attributes.setdefault(attribute_key, []).extend(values)
    return attributes
 | |
| 
 | |
| 
 | |
class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    Depending on the configuration, tokens are either verified locally against
    keys fetched from a JWKS endpoint, or checked remotely through RFC 7662
    token introspection.

    This should be the standard authentication provider for most use cases.
    """

    def __init__(self, config: OAuth2TokenAuthConfig):
        self.config = config
        # time.time() of the last successful JWKS fetch; 0.0 forces a fetch on first use.
        self._jwks_at: float = 0.0
        # Cached signing keys: key ID -> full JWK object from the JWKS endpoint.
        self._jwks: dict[str, dict] = {}
        # Guards the cache check-and-refresh in _refresh_jwks.
        self._jwks_lock = Lock()

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate ``token`` with JWKS if configured, otherwise with introspection.

        Raises:
            ValueError: if neither mechanism is configured or the token is rejected.
        """
        if self.config.jwks:
            return await self.validate_jwt_token(token, scope)
        if self.config.introspection:
            return await self.introspect_token(token, scope)
        raise ValueError("One of jwks or introspection must be configured")

    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a JWT against the cached JWKS keys and build the ``User``.

        Raises:
            ValueError: if the token cannot be decoded/verified or has no ``sub`` claim.
        """
        await self._refresh_jwks()

        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # NOTE(review): the algorithm comes from the token's own, not yet
            # verified, header. Consider restricting it to an allow-list.
            algorithm = header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                issuer=self.config.issuer,
            )
        except Exception as exc:
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims.get("sub")
        if not principal:
            # Report a token without a subject like any other invalid token
            # rather than letting a KeyError escape.
            raise ValueError("JWT token is missing the 'sub' claim")
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using token introspection as defined by RFC 7662.

        Raises:
            ValueError: if introspection is not configured, the endpoint does not
                answer 200, the token is inactive, the response names no
                principal, or any other error occurs while calling the endpoint.
            httpx.TimeoutException: if the introspection request times out.
        """
        form = {
            "token": token,
        }
        if self.config.introspection is None:
            raise ValueError("Introspection is not configured")

        if self.config.introspection.send_secret_in_body:
            form["client_id"] = self.config.introspection.client_id
            form["client_secret"] = self.config.introspection.client_secret
            auth = None
        else:
            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)

        # Use the same TLS policy as the JWKS fetch: a custom CA bundle when one
        # is configured, otherwise the configured verify_tls flag.
        if self.config.tls_cafile:
            verify = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
        else:
            verify = self.config.verify_tls
        try:
            async with httpx.AsyncClient(verify=verify) as client:
                response = await client.post(
                    self.config.introspection.url,
                    data=form,
                    auth=auth,
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
                    raise ValueError(f"Token introspection failed: {response.status_code}")

                fields = response.json()
                if not fields.get("active"):
                    raise ValueError("Token not active")
                # `sub` and `username` are both optional in an introspection
                # response, so fall back from one to the other without a KeyError.
                principal = fields.get("sub") or fields.get("username")
                if not principal:
                    raise ValueError("Token introspection response has no 'sub' or 'username'")
                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
                return User(
                    principal=principal,
                    attributes=access_attributes,
                )
        except httpx.TimeoutException:
            logger.exception("Token introspection request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during token introspection")
            raise ValueError("Token introspection error") from e

    async def close(self):
        """Nothing to release: HTTP clients are created and closed per request."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return OAuth2-specific authentication error message."""
        if self.config.issuer:
            return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
        elif self.config.introspection:
            # Extract domain from introspection URL for a cleaner message
            domain = urlparse(self.config.introspection.url).netloc
            return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
        else:
            return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"

    async def _refresh_jwks(self) -> None:
        """
        Refresh the JWKS cache.

        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
        If the cache is expired, we refresh the JWKS from the JWKS URI.

        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
            * It doesn't have user authentication flows
            * It doesn't have refresh tokens

        Raises:
            ValueError: if JWKS is not configured.
            httpx.HTTPError: if the JWKS endpoint cannot be fetched successfully.
        """
        async with self._jwks_lock:
            if self.config.jwks is None:
                raise ValueError("JWKS is not configured")
            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
                headers = {}
                if self.config.jwks.token:
                    headers["Authorization"] = f"Bearer {self.config.jwks.token}"
                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
                async with httpx.AsyncClient(verify=verify) as client:
                    res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                    res.raise_for_status()
                    jwks_data = res.json()["keys"]
                    updated = {}
                    for k in jwks_data:
                        kid = k.get("kid")
                        if kid is None:
                            # `kid` is optional in a JWK. A key without one can
                            # never be selected by validate_jwt_token, so skip it
                            # instead of failing the whole refresh.
                            continue
                        # Store the entire key object as it may be needed for different algorithms
                        updated[kid] = k
                    # Swap in the new key set only after a complete, successful fetch.
                    self._jwks = updated
                    self._jwks_at = time.time()
 | |
| 
 | |
| 
 | |
class CustomAuthProvider(AuthProvider):
    """Authentication provider that delegates token checks to an external HTTP endpoint."""

    def __init__(self, config: CustomAuthConfig):
        self.config = config
        # Nothing in this class assigns a client here; close() handles one if present.
        self._client = None

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """POST the token and request context to the configured endpoint and return the ``User``.

        Raises:
            ValueError: on a non-200 answer, an unparsable response body, or any
                other error while calling the endpoint.
            httpx.TimeoutException: if the request to the endpoint times out.
        """
        scope = {} if scope is None else scope

        raw_headers = dict(scope.get("headers", []))
        request_path = scope.get("path", "")
        forwarded_headers = {name.decode(): value.decode() for name, value in raw_headers.items()}

        # The credential travels separately as `api_key`; do not forward it as a header.
        forwarded_headers.pop("authorization", None)

        query_params = parse_qs(scope.get("query_string", b"").decode())

        # Request model that is serialized and sent to the endpoint.
        payload = AuthRequest(
            api_key=token,
            request=AuthRequestContext(
                path=request_path,
                headers=forwarded_headers,
                params=query_params,
            ),
        )

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.endpoint,
                    json=payload.model_dump(),
                    timeout=10.0,
                )
                status = response.status_code
                if status != 200:
                    logger.warning(f"Authentication failed with status code: {status}")
                    raise ValueError(f"Authentication failed: {status}")

                # Any failure while decoding or validating the body is reported
                # as a malformed response.
                try:
                    parsed = AuthResponse(**response.json())
                    return User(principal=parsed.principal, attributes=parsed.attributes)
                except Exception as e:
                    logger.exception("Error parsing authentication response")
                    raise ValueError("Invalid authentication response format") from e

        except httpx.TimeoutException:
            logger.exception("Authentication request timed out")
            raise
        except ValueError:
            # Already carries a specific message; pass it through untouched.
            raise
        except Exception as e:
            logger.exception("Error during authentication")
            raise ValueError("Authentication service error") from e

    async def close(self):
        """Close the stored HTTP client, if there is one, and forget it."""
        if not self._client:
            return
        await self._client.aclose()
        self._client = None

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return the error message for this provider, naming the endpoint's host when known."""
        domain = urlparse(self.config.endpoint).netloc
        if not domain:
            return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
        return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
 | |
| 
 | |
| 
 | |
class GitHubTokenAuthProvider(AuthProvider):
    """
    Authentication provider that checks GitHub access tokens against the GitHub API.

    Personal access tokens and OAuth tokens are both accepted; the user
    information returned by GitHub for the token determines the principal
    and the access attributes.
    """

    def __init__(self, config: GitHubTokenAuthConfig):
        self.config = config

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Look up the GitHub user for ``token`` and return it as a ``User``.

        Raises:
            ValueError: if GitHub answers the lookup with an HTTP error status.
        """
        try:
            user_info = await _get_github_user_info(token, self.config.github_api_base_url)
        except httpx.HTTPStatusError as e:
            logger.warning(f"GitHub token validation failed: {e}")
            raise ValueError("GitHub token validation failed. Please check your token and try again.") from e

        gh_user = user_info["user"]
        login = gh_user["login"]

        # Claims offered to the configured mapping. "organizations" falls back
        # to an empty list when the lookup result does not include it.
        claims = {
            "login": login,
            "id": str(gh_user["id"]),
            "organizations": user_info.get("organizations", []),
        }

        return User(
            principal=login,
            attributes=get_attributes_from_claims(claims, self.config.claims_mapping),
        )

    async def close(self):
        """No resources are held by this provider."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return the GitHub-specific message telling callers how to authenticate."""
        return (
            "Authentication required. Please provide a valid GitHub access token "
            "(https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) "
            "in the Authorization header (Bearer <token>)"
        )
 | |
| 
 | |
| 
 | |
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
    """Fetch the authenticated user's profile from the GitHub API.

    Only the ``/user`` endpoint is queried; organizations are not fetched.

    Args:
        access_token: GitHub token, sent as a Bearer credential.
        github_api_base_url: Base URL of the GitHub API; a trailing slash is tolerated.

    Returns:
        ``{"user": <decoded JSON body of GET /user>}``.

    Raises:
        httpx.HTTPStatusError: if GitHub answers with a 4xx/5xx status.
    """
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/vnd.github.v3+json",
        "User-Agent": "llama-stack",
    }

    # Strip trailing slashes so a configured "https://host/" does not produce "//user".
    base_url = str(github_api_base_url).rstrip("/")

    async with httpx.AsyncClient() as client:
        user_response = await client.get(f"{base_url}/user", headers=headers, timeout=10.0)
        user_response.raise_for_status()
        user_data = user_response.json()

        return {
            "user": user_data,
        }
 | |
| 
 | |
| 
 | |
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Build the auth provider that matches the type of ``config.provider_config``.

    Raises:
        ValueError: if the provider config is of an unrecognized type.
    """
    selected = config.provider_config

    # Checked in a fixed order: custom endpoint, OAuth2 token, GitHub token.
    if isinstance(selected, CustomAuthConfig):
        return CustomAuthProvider(selected)
    if isinstance(selected, OAuth2TokenAuthConfig):
        return OAuth2TokenAuthProvider(selected)
    if isinstance(selected, GitHubTokenAuthConfig):
        return GitHubTokenAuthProvider(selected)

    raise ValueError(f"Unknown authentication provider config type: {type(selected)}")
 |