mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add additional auth provider that uses oauth token introspection (#2187)
# What does this PR do? This adds an alternative option to the oauth_token auth provider that can be used with existing authorization services which support token introspection as defined in RFC 7662. This could be useful where token revocation needs to be handled or where opaque tokens (or other non jwt formatted tokens) are used ## Test Plan Tested against keycloak Signed-off-by: Gordon Sim <gsim@redhat.com>
This commit is contained in:
parent
87a4b9cb28
commit
091d8c48f2
3 changed files with 251 additions and 13 deletions
|
@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
)
|
||||
config: dict[str, str] = Field(
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
|
|
|
@ -5,15 +5,18 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel):
|
|||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
config: dict[str, Any] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
|
@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
return attributes
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
jwks_uri: str
|
||||
uri: str
|
||||
cache_ttl: int = 3600
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
tls_cafile: str | None = None
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
audience: str = "llama-stack"
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
|
@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None
|
||||
introspection: OAuth2IntrospectionConfig | None = None
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
|
@ -225,6 +241,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
|
|||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
|
@ -240,8 +264,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks_lock = Lock()
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
if self.config.jwks:
|
||||
return await self.validate_jwt_token(token, self.config.jwks, scope)
|
||||
if self.config.introspection:
|
||||
return await self.introspect_token(token, self.config.introspection, scope)
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
|
||||
async def validate_jwt_token(
|
||||
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks()
|
||||
await self._refresh_jwks(config)
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
@ -269,14 +302,61 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
async def introspect_token(
|
||||
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
|
||||
) -> TokenValidationResult:
|
||||
"""Validate a token using token introspection as defined by RFC 7662."""
|
||||
form = {
|
||||
"token": token,
|
||||
}
|
||||
if config.send_secret_in_body:
|
||||
form["client_id"] = config.client_id
|
||||
form["client_secret"] = config.client_secret
|
||||
auth = None
|
||||
else:
|
||||
auth = (config.client_id, config.client_secret)
|
||||
ssl_ctxt = None
|
||||
if config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||
response = await client.post(
|
||||
config.url,
|
||||
data=form,
|
||||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
fields = response.json()
|
||||
if not fields["active"]:
|
||||
raise ValueError("Token not active")
|
||||
principal = fields["sub"] or fields["username"]
|
||||
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during token introspection")
|
||||
raise ValueError("Token introspection error") from e
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
|
||||
async with self._jwks_lock:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
if time.time() - self._jwks_at > config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res = await client.get(config.uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue