From 9352d9b42c024262647189873df5af78b598a0f1 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 22 May 2025 16:13:07 -0700
Subject: [PATCH] add test for streaming, test against server

---
 llama_stack/apis/tools/tools.py            |  4 +--
 llama_stack/distribution/server/server.py  |  4 ++-
 tests/integration/tool_runtime/test_mcp.py | 32 ++++++++++++++++++++++
 3 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index 2f62b0ba1..29649495c 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel):
 
 
 class ToolStore(Protocol):
-    def get_tool(self, tool_name: str) -> Tool: ...
-    def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
+    async def get_tool(self, tool_name: str) -> Tool: ...
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
 
 
 class ListToolGroupsResponse(BaseModel):
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 7069390cf..d70f06691 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -28,7 +28,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 
-from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
@@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    elif isinstance(exc, AuthenticationRequiredError):
+        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
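Note: the translate_exception change above can only produce a real 401 when
AuthenticationRequiredError is raised before the response starts. Once a
streaming response has begun, the 200 status line is already on the wire, so
the failure has to be reported in-band, which is what the streaming branch of
the test below asserts. A minimal sketch of that pattern (the wrapper name and
shape are hypothetical, not the actual server code):

    # Minimal sketch: once an SSE response has started, a mid-stream exception
    # (e.g. AuthenticationRequiredError) can no longer change the HTTP status,
    # so it is surfaced as an in-band "error" event instead. Hypothetical code.
    import json
    from collections.abc import AsyncIterator

    async def sse_with_errors(stream: AsyncIterator[str]) -> AsyncIterator[str]:
        try:
            async for chunk in stream:
                yield chunk
        except Exception as exc:
            # Too late for a 401 status line; emit an error chunk the client
            # can detect, as the non-library-client branch below checks.
            yield f"data: {json.dumps({'error': {'message': str(exc)}})}\n\n"
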
diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py
index 1b003e152..e553c6a0b 100644
--- a/tests/integration/tool_runtime/test_mcp.py
+++ b/tests/integration/tool_runtime/test_mcp.py
@@ -21,6 +21,9 @@ from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.routing import Mount, Route
 
+from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
+
 AUTH_TOKEN = "test-token"
 
 
@@ -120,6 +123,13 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     auth_headers = {
         "X-LlamaStack-Provider-Data": json.dumps(provider_data),
     }
+
+    try:
+        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
+    except Exception as e:
+        # An error is OK since the toolgroup may not exist
+        print(f"Error unregistering toolgroup: {e}")
+
     llama_stack_client.toolgroups.register(
         toolgroup_id=test_toolgroup_id,
         provider_id="model-context-protocol",
@@ -187,3 +197,25 @@
     third = steps[2]
     assert third.step_type == "inference"
     assert len(third.api_model_response.tool_calls) == 0
+
+    # When streaming, we currently don't check auth headers upfront to fail the
+    # request early, but we should at least generate a 401 later in the process.
+    response = agent.create_turn(
+        session_id=session_id,
+        messages=[
+            {
+                "role": "user",
+                "content": "Yo. Use tools.",
+            }
+        ],
+        stream=True,
+    )
+    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
+        with pytest.raises(AuthenticationRequiredError):
+            for _ in response:
+                pass
+    else:
+        error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
+        assert len(error_chunks) == 1
+        chunk = error_chunks[0].model_dump()
+        assert "Unauthorized" in chunk["error"]["message"]
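With the translate_exception change, the same missing-token failure on a
non-streaming request should now surface as a clean 401 with an
"Authentication required: ..." detail instead of a generic 500. A hypothetical
smoke check against a locally running stack (the port, endpoint path, and
payload are illustrative, not taken from this patch):

    # Hypothetical smoke check: an MCP tool invocation sent without the
    # required token should now come back as 401 rather than 500.
    import requests

    resp = requests.post(
        "http://localhost:8321/v1/tool-runtime/invoke",  # illustrative endpoint
        json={"tool_name": "some_mcp_tool", "kwargs": {}},
    )
    assert resp.status_code == 401  # AuthenticationRequiredError -> 401, not 500
    assert "Authentication required" in resp.json()["detail"]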