mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
more fixes
This commit is contained in:
parent
b5d5d1fba0
commit
cc77a1b4c8
2 changed files with 152 additions and 36 deletions
|
@ -12,7 +12,7 @@ from urllib.parse import parse_qs
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -77,6 +77,7 @@ class AuthProviderType(str, Enum):
|
||||||
|
|
||||||
KUBERNETES = "kubernetes"
|
KUBERNETES = "kubernetes"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
|
|
||||||
|
|
||||||
class AuthProviderConfig(BaseModel):
|
class AuthProviderConfig(BaseModel):
|
||||||
|
@ -177,24 +178,61 @@ class KubernetesAuthProvider(AuthProvider):
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
JWT_AUDIENCE = "llama-stack"
|
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
    """Build access attributes from token claims using a claim-to-attribute mapping.

    For every ``claim_key -> attribute_key`` pair in ``mapping``, the claim's
    values are appended to the matching attribute of a fresh
    ``AccessAttributes``. Several claims may feed the same attribute; their
    values are concatenated in mapping order.

    Args:
        claims: Decoded token claims. Although annotated as strings, a value
            may also be a list at runtime; both are handled.
        mapping: Claim name to ``AccessAttributes`` attribute name.

    Returns:
        The populated ``AccessAttributes``. Mapping entries whose claim is
        absent (or ``None``), or whose target attribute does not exist on
        ``AccessAttributes``, are skipped.
    """
    attributes = AccessAttributes()
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims or not hasattr(attributes, attribute_key):
            continue
        claim = claims[claim_key]
        if claim is None:
            # A null claim carries no values; previously this crashed on `.split()`.
            continue
        if isinstance(claim, str):
            # String claims are whitespace-separated lists (e.g. "foo bar").
            values = claim.split()
        elif isinstance(claim, (list, tuple)):
            # Copy so the attribute never aliases the caller's claims; otherwise
            # the `extend` below would mutate `claims` when a second claim maps
            # to the same attribute.
            values = list(claim)
        else:
            # Other scalars (numbers, booleans) become a single string value.
            values = [str(claim)]

        current = getattr(attributes, attribute_key)
        if current:
            current.extend(values)
        else:
            setattr(attributes, attribute_key, values)
    return attributes
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthProviderConfig(BaseModel):
|
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||||
# The JWKS URI for collecting public keys
|
# The JWKS URI for collecting public keys
|
||||||
jwks_uri: str
|
jwks_uri: str
|
||||||
algorithm: str = "RS256"
|
|
||||||
cache_ttl: int = 3600
|
cache_ttl: int = 3600
|
||||||
|
audience: str = "llama-stack"
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"sub": "roles",
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "teams",
|
||||||
|
"team": "teams",
|
||||||
|
"project": "projects",
|
||||||
|
"tenant": "namespaces",
|
||||||
|
"namespace": "namespaces",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v):
    """Reject claim mappings that point at an empty or unknown attribute.

    ``@field_validator`` must be the outermost decorator: with ``@classmethod``
    stacked on top of it, the validator object is hidden inside a plain
    classmethod and this check is never registered for the field.

    Args:
        v: The ``claims_mapping`` value, claim name to attribute name.

    Returns:
        ``v`` unchanged when every target attribute is valid.

    Raises:
        ValueError: If a mapping value is empty or is not a field of
            ``AccessAttributes``.
    """
    for key, value in v.items():
        if not value:
            raise ValueError(f"claims_mapping value cannot be empty: {key}")
        if value not in AccessAttributes.model_fields:
            raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
    return v
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthProvider(AuthProvider):
|
class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
"""
|
"""
|
||||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||||
|
|
||||||
This should be the standard authentication provider for most use cases.
|
This should be the standard authentication provider for most use cases.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: JWTAuthProviderConfig):
|
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._jwks_at: float = 0.0
|
self._jwks_at: float = 0.0
|
||||||
self._jwks: dict[str, str] = {}
|
self._jwks: dict[str, str] = {}
|
||||||
|
@ -204,15 +242,17 @@ class JWTAuthProvider(AuthProvider):
|
||||||
await self._refresh_jwks()
|
await self._refresh_jwks()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kid = jwt.get_unverified_header(token)["kid"]
|
header = jwt.get_unverified_header(token)
|
||||||
|
kid = header["kid"]
|
||||||
if kid not in self._jwks:
|
if kid not in self._jwks:
|
||||||
raise ValueError(f"Unknown key ID: {kid}")
|
raise ValueError(f"Unknown key ID: {kid}")
|
||||||
key = self._jwks[kid]
|
key_data = self._jwks[kid]
|
||||||
|
algorithm = header.get("alg", "RS256")
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
token,
|
token,
|
||||||
key,
|
key_data,
|
||||||
algorithms=[self.config.algorithm],
|
algorithms=[algorithm],
|
||||||
audience=JWT_AUDIENCE,
|
audience=self.config.audience,
|
||||||
options={"verify_exp": True},
|
options={"verify_exp": True},
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
@ -220,29 +260,11 @@ class JWTAuthProvider(AuthProvider):
|
||||||
|
|
||||||
# There are other standard claims, the most relevant of which is `scope`.
|
# There are other standard claims, the most relevant of which is `scope`.
|
||||||
# We should incorporate these into the access attributes.
|
# We should incorporate these into the access attributes.
|
||||||
principal = f"{claims['iss']}:{claims['sub']}"
|
principal = claims["sub"]
|
||||||
|
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||||
teams = claims.get("teams", [])
|
|
||||||
if not teams:
|
|
||||||
if team := claims.get("team", claims.get("team_id")):
|
|
||||||
teams = [team]
|
|
||||||
projects = claims.get("projects", [])
|
|
||||||
if not projects:
|
|
||||||
if project := claims.get("project", claims.get("project_id")):
|
|
||||||
projects = [project]
|
|
||||||
namespaces = claims.get("namespaces", [])
|
|
||||||
if not namespaces:
|
|
||||||
if namespace := claims.get("namespace", claims.get("tenant")):
|
|
||||||
namespaces = [namespace]
|
|
||||||
|
|
||||||
return TokenValidationResult(
|
return TokenValidationResult(
|
||||||
principal=principal,
|
principal=principal,
|
||||||
access_attributes=AccessAttributes(
|
access_attributes=access_attributes,
|
||||||
roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0
|
|
||||||
teams=teams,
|
|
||||||
projects=projects,
|
|
||||||
namespaces=namespaces,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
|
@ -250,10 +272,15 @@ class JWTAuthProvider(AuthProvider):
|
||||||
|
|
||||||
async def _refresh_jwks(self) -> None:
|
async def _refresh_jwks(self) -> None:
|
||||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||||
with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
self._jwks = {k["kid"]: k for k in res.json()["keys"]}
|
jwks_data = res.json()["keys"]
|
||||||
|
self._jwks = {}
|
||||||
|
for k in jwks_data:
|
||||||
|
kid = k["kid"]
|
||||||
|
# Store the entire key object as it may be needed for different algorithms
|
||||||
|
self._jwks[kid] = k
|
||||||
self._jwks_at = time.time()
|
self._jwks_at = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
@ -340,8 +367,8 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||||
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||||
elif provider_type == "custom":
|
elif provider_type == "custom":
|
||||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||||
elif provider_type == "jwt":
|
elif provider_type == "oauth2_token":
|
||||||
return JWTAuthProvider(JWTAuthProviderConfig.model_validate(config.config))
|
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||||
else:
|
else:
|
||||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||||
|
|
|
@ -27,6 +27,10 @@ class MockResponse:
|
||||||
def json(self):
|
def json(self):
|
||||||
return self._json_data
|
return self._json_data
|
||||||
|
|
||||||
|
def raise_for_status(self):
    """Raise a generic ``Exception`` unless the stored status code is exactly 200."""
    if self.status_code == 200:
        return
    raise Exception(f"HTTP error: {self.status_code}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_auth_endpoint():
|
def mock_auth_endpoint():
|
||||||
|
@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
|
||||||
assert attributes["roles"] == ["admin"]
|
assert attributes["roles"] == ["admin"]
|
||||||
|
|
||||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||||
|
|
||||||
|
|
||||||
|
# oauth2 token provider tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def oauth2_app():
    """FastAPI app guarded by the authentication middleware in oauth2_token mode.

    Exposes a single ``GET /test`` route that returns a fixed success message.
    """
    application = FastAPI()
    provider_settings = {
        "jwks_uri": "http://mock-authz-service/token/introspect",
        "cache_ttl": "3600",
        "audience": "llama-stack",
    }
    application.add_middleware(
        AuthenticationMiddleware,
        auth_config=AuthProviderConfig(
            provider_type=AuthProviderType.OAUTH2_TOKEN,
            config=provider_settings,
        ),
    )

    @application.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return application
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def oauth2_client(oauth2_app):
    """Test client bound to the ``oauth2_app`` fixture."""
    client = TestClient(oauth2_app)
    return client
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_oauth2(oauth2_client):
    """A request with no Authorization header must be answered with 401."""
    result = oauth2_client.get("/test")
    assert result.status_code == 401
    error_message = result.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_auth_header_format_oauth2(oauth2_client):
    """An Authorization header that is not a Bearer token must be answered with 401."""
    bad_headers = {"Authorization": "InvalidFormat token123"}
    result = oauth2_client.get("/test", headers=bad_headers)
    assert result.status_code == 401
    error_message = result.json()["error"]["message"]
    assert "Missing or invalid Authorization header" in error_message
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_jwks_response(*args, **kwargs):
    """Async stand-in for an HTTP GET: returns a 200 response holding one HS256 key."""
    symmetric_key = {
        "kid": "1234567890",
        "kty": "oct",
        "alg": "HS256",
        "use": "sig",
        "k": "MTIzNDU2Nzg5MA",  # Base64-encoded "1234567890"
    }
    jwks_document = {"keys": [symmetric_key]}
    return MockResponse(200, jwks_document)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def jwt_token_valid():
    """HS256-signed token whose header carries the ``kid`` served by the mock JWKS."""
    from jose import jwt

    payload = {
        "sub": "my-user",
        "groups": ["group1", "group2"],
        "scope": "foo bar",
        "aud": "llama-stack",
    }
    token_headers = {"kid": "1234567890"}
    # correctly signed jwt token with "kid" in header
    return jwt.encode(payload, key="1234567890", algorithm="HS256", headers=token_headers)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
    """A correctly signed bearer token is accepted while the JWKS fetch is patched."""
    auth_headers = {"Authorization": f"Bearer {jwt_token_valid}"}
    result = oauth2_client.get("/test", headers=auth_headers)
    assert result.status_code == 200
    assert result.json() == {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add more tests for oauth2 token provider
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue