add test for streaming, test against server

This commit is contained in:
Ashwin Bharambe 2025-05-22 16:13:07 -07:00
parent 0d67e17a91
commit 9352d9b42c
3 changed files with 37 additions and 3 deletions

View file

@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
async def get_tool(self, tool_name: str) -> Tool: ...
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel):

View file

@ -28,7 +28,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,

View file

@ -21,6 +21,9 @@ from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
AUTH_TOKEN = "test-token"
@ -120,6 +123,13 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
try:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
except Exception as e:
# An error is OK since the toolgroup may not exist
print(f"Error unregistering toolgroup: {e}")
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
@ -187,3 +197,25 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
third = steps[2]
assert third.step_type == "inference"
assert len(third.api_model_response.tool_calls) == 0
# when streaming, we currently don't check auth headers upfront and fail the request
# early. but we should at least be generating a 401 later in the process.
response = agent.create_turn(
session_id=session_id,
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
}
],
stream=True,
)
if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
with pytest.raises(AuthenticationRequiredError):
for _ in response:
pass
else:
error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
assert len(error_chunks) == 1
chunk = error_chunks[0].model_dump()
assert "Unauthorized" in chunk["error"]["message"]