mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fixed test_chat_agent
Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
This commit is contained in:
parent
5ca575eefe
commit
1181754c5b
2 changed files with 56 additions and 52 deletions
|
@ -93,14 +93,7 @@ agent_config = AgentConfig(
|
|||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
# 'documents_db_id' holds the ID of the registered vector database
|
||||
# where the provided documents will be ingested. This argument is mandatory
|
||||
# when the 'documents' parameter is provided in a 'create_turn' invocation.
|
||||
# When provided, 'documents_db_id' will also be used to extract contextual information
|
||||
# for the query.
|
||||
"documents_db_id": vector_db_id,
|
||||
# Optionally, the 'vector_db_ids' argument can specify additional vector databases
|
||||
# to use at query time.
|
||||
"vector_db_ids": vector_db_id,
|
||||
},
|
||||
}
|
||||
],
|
||||
|
|
|
@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnResponseTurnCompletePayload,
|
||||
StepType,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.content_types import URL, TextDelta
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
LogProbConfig,
|
||||
|
@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolGroupsResponse,
|
||||
ListToolsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
|
@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
|
|||
ToolInvocationResult,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
|
@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|||
class MockInferenceAPI:
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
async def stream_response():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="start",
|
||||
delta="",
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="progress",
|
||||
delta="AI is a fascinating field...",
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text="AI is a fascinating field..."),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type="complete",
|
||||
delta="",
|
||||
stop_reason="end_of_turn",
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -133,12 +138,13 @@ class MockToolGroupsAPI:
|
|||
provider_resource_id=toolgroup_id,
|
||||
)
|
||||
|
||||
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||
return []
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=[])
|
||||
|
||||
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||
if tool_group_id == MEMORY_TOOLGROUP:
|
||||
return [
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
if toolgroup_id == MEMORY_TOOLGROUP:
|
||||
return ListToolsResponse(
|
||||
data=[
|
||||
Tool(
|
||||
identifier=MEMORY_QUERY_TOOL,
|
||||
provider_resource_id=MEMORY_QUERY_TOOL,
|
||||
|
@ -149,8 +155,10 @@ class MockToolGroupsAPI:
|
|||
parameters=[],
|
||||
)
|
||||
]
|
||||
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
|
||||
return [
|
||||
)
|
||||
if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
|
||||
return ListToolsResponse(
|
||||
data=[
|
||||
Tool(
|
||||
identifier="code_interpreter",
|
||||
provider_resource_id="code_interpreter",
|
||||
|
@ -161,7 +169,8 @@ class MockToolGroupsAPI:
|
|||
parameters=[],
|
||||
)
|
||||
]
|
||||
return []
|
||||
)
|
||||
return ListToolsResponse(data=[])
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return Tool(
|
||||
|
@ -174,7 +183,7 @@ class MockToolGroupsAPI:
|
|||
parameters=[],
|
||||
)
|
||||
|
||||
async def unregister_tool_group(self, tool_group_id: str) -> None:
|
||||
async def unregister_tool_group(self, toolgroup_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
|
|||
chat_agent = await impl.get_agent(response.agent_id)
|
||||
|
||||
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||
tool_defs_names = [t.tool_name for t in tool_defs]
|
||||
if expected_memory:
|
||||
assert MEMORY_QUERY_TOOL in tool_defs
|
||||
assert MEMORY_QUERY_TOOL in tool_defs_names
|
||||
if expected_code_interpreter:
|
||||
assert BuiltinTool.code_interpreter in tool_defs
|
||||
assert BuiltinTool.code_interpreter in tool_defs_names
|
||||
if expected_memory and expected_code_interpreter:
|
||||
# override the tools for turn
|
||||
new_tool_defs, _ = await chat_agent._get_tool_defs(
|
||||
|
@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
|
|||
)
|
||||
]
|
||||
)
|
||||
assert MEMORY_QUERY_TOOL in new_tool_defs
|
||||
assert BuiltinTool.code_interpreter not in new_tool_defs
|
||||
new_tool_defs_names = [t.tool_name for t in new_tool_defs]
|
||||
assert MEMORY_QUERY_TOOL in new_tool_defs_names
|
||||
assert BuiltinTool.code_interpreter not in new_tool_defs_names
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue