Merge remote-tracking branch 'upstream/main' into rm-build

This commit is contained in:
Charlie Doern 2025-12-01 13:52:18 -05:00
commit 39ad54696c
81 changed files with 585 additions and 4236 deletions

View file

@ -155,9 +155,6 @@ def old_config():
provider_type: inline::meta-reference
config: {{}}
api_providers:
telemetry:
provider_type: noop
config: {{}}
"""
)
@ -181,7 +178,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
def test_parse_and_maybe_upgrade_config_old_format(old_config):
result = parse_and_maybe_upgrade_config(old_config)
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
assert all(api in result.providers for api in ["inference", "safety", "memory"])
safety_provider = result.providers["safety"][0]
assert safety_provider.provider_type == "inline::meta-reference"
assert "llama_guard_shield" in safety_provider.config

View file

@ -83,7 +83,7 @@ class TestProviderInitialization:
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
provider = await get_provider_impl(config, mock_deps, policy=[])
assert provider is not None
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
@ -97,7 +97,7 @@ class TestProviderInitialization:
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
provider = await get_provider_impl(config, mock_deps, policy=[])
assert provider is not None
assert provider.safety_api is None

View file

@ -364,23 +364,6 @@ def test_invalid_auth_header_format_oauth2(oauth2_client):
assert "Invalid Authorization header format" in response.json()["error"]["message"]
async def mock_jwks_response(*args, **kwargs):
return MockResponse(
200,
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": base64.b64encode(b"foobarbaz").decode(),
}
]
},
)
@pytest.fixture
def jwt_token_valid():
import jwt
@ -421,28 +404,60 @@ def mock_jwks_urlopen():
yield mock_urlopen
@pytest.fixture
def mock_jwks_urlopen_with_auth_required():
"""Mock urllib.request.urlopen that requires Bearer token for JWKS requests."""
with patch("urllib.request.urlopen") as mock_urlopen:
def side_effect(request, **kwargs):
# Check if Authorization header is present
auth_header = request.headers.get("Authorization") if hasattr(request, "headers") else None
if not auth_header or not auth_header.startswith("Bearer "):
# Simulate 401 Unauthorized
import urllib.error
raise urllib.error.HTTPError(
url=request.full_url if hasattr(request, "full_url") else "",
code=401,
msg="Unauthorized",
hdrs={},
fp=None,
)
# Mock the JWKS response for PyJWKClient
mock_response = Mock()
mock_response.read.return_value = json.dumps(
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": base64.b64encode(b"foobarbaz").decode(),
}
]
}
).encode()
return mock_response
mock_urlopen.side_effect = side_effect
yield mock_urlopen
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, mock_jwks_urlopen, suppress_auth_errors):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
assert response.status_code == 401
assert "Invalid JWT token" in response.json()["error"]["message"]
async def mock_auth_jwks_response(*args, **kwargs):
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
return MockResponse(401, {})
authz = kwargs["headers"]["Authorization"]
if authz != "Bearer my-jwks-token":
return MockResponse(401, {})
return await mock_jwks_response(args, kwargs)
@pytest.fixture
def oauth2_app_with_jwks_token():
app = FastAPI()
@ -472,8 +487,9 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
return TestClient(oauth2_app_with_jwks_token)
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
def test_oauth2_with_jwks_token_expected(
oauth2_client, jwt_token_valid, mock_jwks_urlopen_with_auth_required, suppress_auth_errors
):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 401