more fixes

This commit is contained in:
Ashwin Bharambe 2025-05-18 07:40:11 -07:00
parent b5d5d1fba0
commit cc77a1b4c8
2 changed files with 152 additions and 36 deletions

View file

@ -12,7 +12,7 @@ from urllib.parse import parse_qs
import httpx import httpx
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -77,6 +77,7 @@ class AuthProviderType(str, Enum):
KUBERNETES = "kubernetes" KUBERNETES = "kubernetes"
CUSTOM = "custom" CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel): class AuthProviderConfig(BaseModel):
@ -177,24 +178,61 @@ class KubernetesAuthProvider(AuthProvider):
self._client = None self._client = None
# Default audience value; the provider config's `audience` field carries the same default.
JWT_AUDIENCE = "llama-stack"


def get_attributes_from_claims(claims: dict[str, object], mapping: dict[str, str]) -> AccessAttributes:
    """
    Build access attributes from token claims according to a claim -> attribute mapping.

    For every ``claim_key -> attribute_key`` pair in ``mapping`` whose claim is present
    and whose attribute exists on ``AccessAttributes``, the claim's values are appended
    to that attribute. Several claims may feed the same attribute; their values are
    concatenated in mapping order.

    :param claims: decoded token claims. A string claim is split on whitespace
        (so "a b" contributes two values); a list/tuple claim contributes its items.
    :param mapping: claim name -> ``AccessAttributes`` field name.
    :returns: a new ``AccessAttributes`` populated from the mapped claims.
    """
    attributes = AccessAttributes()
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims or not hasattr(attributes, attribute_key):
            continue

        claim = claims[claim_key]
        if isinstance(claim, str):
            values = claim.split()
        elif isinstance(claim, (list, tuple)):
            # Copy so the attribute never aliases the caller's claim object; otherwise a
            # later extend() for another claim would silently mutate `claims` itself.
            values = list(claim)
        else:
            # Numbers, booleans, null, nested objects: not usable as attribute values.
            # Ignore them instead of failing on a missing .split().
            continue

        current = getattr(attributes, attribute_key)
        if current:
            current.extend(values)
        else:
            setattr(attributes, attribute_key, values)
    return attributes
class JWTAuthProviderConfig(BaseModel): class OAuth2TokenAuthProviderConfig(BaseModel):
# The JWKS URI for collecting public keys # The JWKS URI for collecting public keys
jwks_uri: str jwks_uri: str
algorithm: str = "RS256"
cache_ttl: int = 3600 cache_ttl: int = 3600
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
# NOTE: @field_validator must be the outermost decorator. With @classmethod on top,
# the class attribute is a plain classmethod object that pydantic does not register,
# so the checks below would never run.
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v: dict[str, str]) -> dict[str, str]:
    """
    Reject claim mappings that point at nothing or at an unknown access attribute.

    :param v: the ``claims_mapping`` value (claim name -> attribute name).
    :returns: ``v`` unchanged when every target is a field of ``AccessAttributes``.
    :raises ValueError: if a target is empty or is not an ``AccessAttributes`` field.
    """
    for key, value in v.items():
        if not value:
            raise ValueError(f"claims_mapping value cannot be empty: {key}")
        if value not in AccessAttributes.model_fields:
            raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
    return v
class JWTAuthProvider(AuthProvider): class OAuth2TokenAuthProvider(AuthProvider):
""" """
JWT token authentication provider that validates a JWT token and extracts access attributes. JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases. This should be the standard authentication provider for most use cases.
""" """
def __init__(self, config: JWTAuthProviderConfig): def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config self.config = config
self._jwks_at: float = 0.0 self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {} self._jwks: dict[str, str] = {}
@ -204,15 +242,17 @@ class JWTAuthProvider(AuthProvider):
await self._refresh_jwks() await self._refresh_jwks()
try: try:
kid = jwt.get_unverified_header(token)["kid"] header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks: if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}") raise ValueError(f"Unknown key ID: {kid}")
key = self._jwks[kid] key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode( claims = jwt.decode(
token, token,
key, key_data,
algorithms=[self.config.algorithm], algorithms=[algorithm],
audience=JWT_AUDIENCE, audience=self.config.audience,
options={"verify_exp": True}, options={"verify_exp": True},
) )
except Exception as exc: except Exception as exc:
@ -220,29 +260,11 @@ class JWTAuthProvider(AuthProvider):
# There are other standard claims, the most relevant of which is `scope`. # There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes. # We should incorporate these into the access attributes.
principal = f"{claims['iss']}:{claims['sub']}" principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
teams = claims.get("teams", [])
if not teams:
if team := claims.get("team", claims.get("team_id")):
teams = [team]
projects = claims.get("projects", [])
if not projects:
if project := claims.get("project", claims.get("project_id")):
projects = [project]
namespaces = claims.get("namespaces", [])
if not namespaces:
if namespace := claims.get("namespace", claims.get("tenant")):
namespaces = [namespace]
return TokenValidationResult( return TokenValidationResult(
principal=principal, principal=principal,
access_attributes=AccessAttributes( access_attributes=access_attributes,
roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0
teams=teams,
projects=projects,
namespaces=namespaces,
),
) )
async def close(self): async def close(self):
@ -250,10 +272,15 @@ class JWTAuthProvider(AuthProvider):
async def _refresh_jwks(self) -> None:
    """
    Re-download the JWKS document once the cached copy is older than ``cache_ttl`` seconds.

    The key table maps each key's ``kid`` to the entire key object, since different
    algorithms may need different members of it. Entries without a ``kid`` are skipped.

    Raises whatever the HTTP request or JSON decoding raises; in that case the
    previously cached keys and timestamp are left untouched.
    """
    if time.time() - self._jwks_at <= self.config.cache_ttl:
        return

    async with httpx.AsyncClient() as client:
        res = await client.get(self.config.jwks_uri, timeout=5)
        res.raise_for_status()
        jwks_data = res.json()["keys"]

    # Build the replacement table completely before assigning it, so a malformed
    # response can never leave the provider with an emptied or half-filled cache.
    self._jwks = {k["kid"]: k for k in jwks_data if "kid" in k}
    self._jwks_at = time.time()
@ -340,8 +367,8 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom": elif provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "jwt": elif provider_type == "oauth2_token":
return JWTAuthProvider(JWTAuthProviderConfig.model_validate(config.config)) return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else: else:
supported_providers = ", ".join([t.value for t in AuthProviderType]) supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}") raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

View file

@ -27,6 +27,10 @@ class MockResponse:
def json(self): def json(self):
return self._json_data return self._json_data
def raise_for_status(self):
    """Raise an Exception carrying the status code unless the status is exactly 200."""
    if self.status_code == 200:
        return
    raise Exception(f"HTTP error: {self.status_code}")
@pytest.fixture @pytest.fixture
def mock_auth_endpoint(): def mock_auth_endpoint():
@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
assert attributes["roles"] == ["admin"] assert attributes["roles"] == ["admin"]
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send) mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
# oauth2 token provider tests
@pytest.fixture
def oauth2_app():
    """FastAPI app with a single /test route behind the oauth2_token authentication middleware."""
    application = FastAPI()
    provider_settings = {
        "jwks_uri": "http://mock-authz-service/token/introspect",
        "cache_ttl": "3600",
        "audience": "llama-stack",
    }
    application.add_middleware(
        AuthenticationMiddleware,
        auth_config=AuthProviderConfig(
            provider_type=AuthProviderType.OAUTH2_TOKEN,
            config=provider_settings,
        ),
    )

    @application.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return application
@pytest.fixture
def oauth2_client(oauth2_app):
    """Test client bound to the app produced by the ``oauth2_app`` fixture."""
    client = TestClient(oauth2_app)
    return client
def test_missing_auth_header_oauth2(oauth2_client):
    """A request without an Authorization header gets a 401 with the expected message."""
    resp = oauth2_client.get("/test")
    assert resp.status_code == 401
    error_message = resp.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
def test_invalid_auth_header_format_oauth2(oauth2_client):
    """An Authorization header that is not a Bearer token gets a 401 with the expected message."""
    malformed_headers = {"Authorization": "InvalidFormat token123"}
    resp = oauth2_client.get("/test", headers=malformed_headers)
    assert resp.status_code == 401
    error_message = resp.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
async def mock_jwks_response(*args, **kwargs):
    """Return a 200 MockResponse whose JSON body is a JWKS document with one HS256 key."""
    signing_key = {
        "kid": "1234567890",
        "kty": "oct",
        "alg": "HS256",
        "use": "sig",
        "k": "MTIzNDU2Nzg5MA",  # Base64-encoded "1234567890"
    }
    return MockResponse(200, {"keys": [signing_key]})
@pytest.fixture
def jwt_token_valid():
    """HS256-signed token for "my-user" with audience "llama-stack" and a "kid" header."""
    from jose import jwt

    payload = {
        "sub": "my-user",
        "groups": ["group1", "group2"],
        "scope": "foo bar",
        "aud": "llama-stack",
    }
    token_headers = {"kid": "1234567890"}
    # correctly signed jwt token with "kid" in header
    return jwt.encode(payload, key="1234567890", algorithm="HS256", headers=token_headers)
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
    """A correctly signed bearer token reaches the endpoint while the JWKS fetch is patched."""
    auth_headers = {"Authorization": f"Bearer {jwt_token_valid}"}
    response = oauth2_client.get("/test", headers=auth_headers)
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}
# TODO: add more tests for oauth2 token provider