Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)
add test for streaming, test against server
This commit is contained in:
parent 0d67e17a91 · commit 9352d9b42c
3 changed files with 37 additions and 3 deletions
@@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel):
 
 
 class ToolStore(Protocol):
-    def get_tool(self, tool_name: str) -> Tool: ...
-    def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
+    async def get_tool(self, tool_name: str) -> Tool: ...
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
 
 
 class ListToolGroupsResponse(BaseModel):
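The first hunk turns the ToolStore protocol methods async. A minimal sketch of what that implies, using stand-in Tool and ToolGroup classes rather than the real llama_stack types:

from typing import Protocol


class Tool: ...


class ToolGroup: ...


class ToolStore(Protocol):
    async def get_tool(self, tool_name: str) -> Tool: ...
    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...


async def describe(store: ToolStore, name: str) -> Tool:
    # Call sites must now await; a plain store.get_tool(name) would just
    # return an un-awaited coroutine.
    return await store.get_tool(name)

Every implementation of ToolStore now has to declare these methods with async def, and previously synchronous call sites need an await.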
@@ -28,7 +28,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 
-from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,

@@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    elif isinstance(exc, AuthenticationRequiredError):
+        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
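This hunk extends translate_exception so AuthenticationRequiredError maps to a 401, alongside the existing 504 and 501 branches. A self-contained sketch of just the new branch; the AuthenticationRequiredError class below is a stand-in for the real one in llama_stack.distribution.datatypes:

from fastapi import HTTPException


class AuthenticationRequiredError(Exception):
    """Stand-in for llama_stack.distribution.datatypes.AuthenticationRequiredError."""


def translate_exception(exc: Exception) -> HTTPException:
    # Mirror of the new branch: domain-level auth errors become HTTP 401s
    # instead of falling through to the generic 500 handler.
    if isinstance(exc, AuthenticationRequiredError):
        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
    return HTTPException(status_code=500, detail="Internal server error")


assert translate_exception(AuthenticationRequiredError("missing token")).status_code == 401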
@@ -21,6 +21,9 @@ from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.routing import Mount, Route
 
+from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
+
 AUTH_TOKEN = "test-token"
 
 

@@ -120,6 +123,13 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     auth_headers = {
         "X-LlamaStack-Provider-Data": json.dumps(provider_data),
     }
+
+    try:
+        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
+    except Exception as e:
+        # An error is OK since the toolgroup may not exist
+        print(f"Error unregistering toolgroup: {e}")
+
     llama_stack_client.toolgroups.register(
         toolgroup_id=test_toolgroup_id,
         provider_id="model-context-protocol",
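The try/except around unregister is best-effort cleanup, so the test can be re-run against a server that still holds state from a previous run. As a hypothetical alternative (not part of this commit), the same cleanup could live in a pytest fixture; auth_headers and test_toolgroup_id are assumed to be provided as fixtures themselves:

import pytest


@pytest.fixture
def clean_toolgroup(llama_stack_client, auth_headers, test_toolgroup_id):
    # Hypothetical fixture mirroring the test's inline cleanup: ensure no
    # stale registration survives from an earlier run.
    try:
        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
    except Exception:
        pass  # fine if the toolgroup does not exist yet
    return test_toolgroup_id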
@@ -187,3 +197,25 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     third = steps[2]
     assert third.step_type == "inference"
     assert len(third.api_model_response.tool_calls) == 0
+
+    # when streaming, we currently don't check auth headers upfront and fail the request
+    # early. but we should at least be generating a 401 later in the process.
+    response = agent.create_turn(
+        session_id=session_id,
+        messages=[
+            {
+                "role": "user",
+                "content": "Yo. Use tools.",
+            }
+        ],
+        stream=True,
+    )
+    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
+        with pytest.raises(AuthenticationRequiredError):
+            for _ in response:
+                pass
+    else:
+        error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
+        assert len(error_chunks) == 1
+        chunk = error_chunks[0].model_dump()
+        assert "Unauthorized" in chunk["error"]["message"]
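The new streaming assertions follow from the comment in the diff: by the time authentication fails inside the stream, the server has already committed a 200 status, so the failure can only surface in-band. The library client re-raises the exception directly, while the HTTP client sees a single error chunk. A rough sketch of the server-side shape that produces this behavior (not llama-stack's actual code):

import json
from collections.abc import AsyncIterator


async def sse_generator(gen: AsyncIterator) -> AsyncIterator[str]:
    # Once the first chunk is yielded, the 200 status line is already on
    # the wire, so a later failure cannot become a real 401 response.
    try:
        async for item in gen:
            yield f"data: {json.dumps(item)}\n\n"
    except Exception as exc:  # e.g. AuthenticationRequiredError from a tool call
        # Too late for a real 401; emit an in-band error event instead.
        yield f"data: {json.dumps({'error': {'message': f'Unauthorized: {exc}'}})}\n\n"

This is why the test drains the iterator: the exception (or the error chunk) only appears once the stream is actually consumed.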