mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-22 16:23:08 +00:00
Merge branch 'main' into responses-and-safety
This commit is contained in:
commit
90ee3001d9
33 changed files with 16970 additions and 713 deletions
|
@ -175,6 +175,9 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
|
|
|
@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
|
|||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
|
|||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
||||
|
||||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
@ -122,7 +122,7 @@ def mock_impls():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||
def middleware_with_mocks(mock_auth_endpoint):
|
||||
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
|
@ -137,18 +137,20 @@ def scope_middleware_with_mocks(mock_auth_endpoint):
|
|||
# Mock the route_impls to simulate finding routes with required scopes
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||
|
||||
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||
routes = {
|
||||
("POST", "/test/scoped"): WebMethod(route="/test/scoped", method="POST", required_scope="test.read"),
|
||||
("GET", "/test/public"): WebMethod(route="/test/public", method="GET"),
|
||||
("GET", "/health"): WebMethod(route="/health", method="GET", require_authentication=False),
|
||||
("GET", "/version"): WebMethod(route="/version", method="GET", require_authentication=False),
|
||||
("GET", "/models/list"): WebMethod(route="/models/list", method="GET", require_authentication=True),
|
||||
}
|
||||
|
||||
# Mock the route finding logic
|
||||
def mock_find_matching_route(method, path, route_impls):
|
||||
if method == "POST" and path == "/test/scoped":
|
||||
return None, {}, "/test/scoped", scoped_webmethod
|
||||
elif method == "GET" and path == "/test/public":
|
||||
return None, {}, "/test/public", public_webmethod
|
||||
else:
|
||||
raise ValueError("No matching route")
|
||||
webmethod = routes.get((method, path))
|
||||
if webmethod:
|
||||
return None, {}, path, webmethod
|
||||
raise ValueError("No matching route")
|
||||
|
||||
import llama_stack.core.server.auth
|
||||
|
||||
|
@ -659,9 +661,9 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
|
||||
# Scope-based authorization tests
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_success(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user with required scope can access protected endpoint"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
@ -680,9 +682,9 @@ async def test_scope_authorization_success(scope_middleware_with_mocks, valid_ap
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_scope_authorization_denied(middleware_with_mocks, valid_api_key):
|
||||
"""Test that user without required scope gets 403 access denied"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
@ -710,9 +712,9 @@ async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api
|
|||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||
async def test_public_endpoint_no_scope_required(middleware_with_mocks, valid_api_key):
|
||||
"""Test that public endpoints work without specific scopes"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
@ -730,9 +732,9 @@ async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, va
|
|||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||
async def test_scope_authorization_no_auth_disabled(middleware_with_mocks):
|
||||
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
|
@ -907,3 +909,41 @@ def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mo
|
|||
request_body = call_args[1]["json"]
|
||||
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||
assert request_body["kind"] == "SelfSubjectReview"
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_access_health(middleware_with_mocks):
|
||||
"""Test that /health endpoints can be accessed without authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /health without auth header (level prefix v1 is added by router)
|
||||
scope = {"type": "http", "path": "/health", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should allow the request to proceed without authentication
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the request was passed to the app
|
||||
mock_app.assert_called_once_with(scope, receive, send)
|
||||
|
||||
# Verify that no error response was sent
|
||||
assert not any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
||||
|
||||
async def test_unauthenticated_endpoint_denied_for_other_paths(middleware_with_mocks):
|
||||
"""Test that endpoints other than /health and /version require authentication"""
|
||||
middleware, mock_app = middleware_with_mocks
|
||||
|
||||
# Test request to /models/list without auth header
|
||||
scope = {"type": "http", "path": "/models/list", "headers": [], "method": "GET"}
|
||||
receive = AsyncMock()
|
||||
send = AsyncMock()
|
||||
|
||||
# Should return 401 error
|
||||
await middleware(scope, receive, send)
|
||||
|
||||
# Verify that the app was NOT called
|
||||
mock_app.assert_not_called()
|
||||
|
||||
# Verify that a 401 error response was sent
|
||||
assert any(call[0][0].get("status") == 401 for call in send.call_args_list)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue