diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 08d68fefa..7e56875fd 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -93,14 +93,7 @@ agent_config = AgentConfig( { "name": "builtin::rag/knowledge_search", "args": { - # 'documents_db_id' holds the ID of the registered vector database - # where the provided documents will be ingested. This argument is mandatory - # when the 'documents' parameter is provided in a 'create_turn' invocation. - # When provided, 'documents_db_id' will also be used to extract contextual information - # for the query. - "documents_db_id": vector_db_id, - # Optionally, the 'vector_db_ids' argument can specify additional vector databases - # to use at query time. + "vector_db_ids": [vector_db_id], }, } ], diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index b802937b6..84ab364b7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -16,10 +16,11 @@ from llama_stack.apis.agents import ( AgentTurnResponseTurnCompletePayload, StepType, ) -from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.content_types import URL, TextDelta from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEvent, + ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, LogProbConfig, @@ -27,12 +28,15 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, UserMessage, ) from llama_stack.apis.safety import RunShieldResponse from llama_stack.apis.tools import ( + ListToolGroupsResponse, + ListToolsResponse, Tool, ToolDef, ToolGroup, @@ -40,7 +44,7 @@ from
llama_stack.apis.tools import ( ToolInvocationResult, ) from llama_stack.apis.vector_io import QueryChunksResponse -from llama_stack.models.llama.datatypes import BuiltinTool +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) @@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MockInferenceAPI: async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: async def stream_response(): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="start", - delta="", + event_type=ChatCompletionResponseEventType.start, + delta=TextDelta(text=""), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="progress", - delta="AI is a fascinating field...", + event_type=ChatCompletionResponseEventType.progress, + delta=TextDelta(text="AI is a fascinating field..."), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="complete", - delta="", - stop_reason="end_of_turn", + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=""), + stop_reason=StopReason.end_of_turn, ) ) @@ -133,35 +138,39 @@ class MockToolGroupsAPI: provider_resource_id=toolgroup_id, ) - async def list_tool_groups(self) -> List[ToolGroup]: - return [] + async def 
list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=[]) - async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: - if tool_group_id == MEMORY_TOOLGROUP: - return [ - Tool( - identifier=MEMORY_QUERY_TOOL, - provider_resource_id=MEMORY_QUERY_TOOL, - toolgroup_id=MEMORY_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::rag", - parameters=[], - ) - ] - if tool_group_id == CODE_INTERPRETER_TOOLGROUP: - return [ - Tool( - identifier="code_interpreter", - provider_resource_id="code_interpreter", - toolgroup_id=CODE_INTERPRETER_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::code_interpreter", - parameters=[], - ) - ] - return [] + async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: + if toolgroup_id == MEMORY_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier=MEMORY_QUERY_TOOL, + provider_resource_id=MEMORY_QUERY_TOOL, + toolgroup_id=MEMORY_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::rag", + parameters=[], + ) + ] + ) + if toolgroup_id == CODE_INTERPRETER_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier="code_interpreter", + provider_resource_id="code_interpreter", + toolgroup_id=CODE_INTERPRETER_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::code_interpreter", + parameters=[], + ) + ] + ) + return ListToolsResponse(data=[]) async def get_tool(self, tool_name: str) -> Tool: return Tool( @@ -174,7 +183,7 @@ class MockToolGroupsAPI: parameters=[], ) - async def unregister_tool_group(self, tool_group_id: str) -> None: + async def unregister_tool_group(self, toolgroup_id: str) -> None: pass @@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex chat_agent = await impl.get_agent(response.agent_id) tool_defs, _ = await 
chat_agent._get_tool_defs() + tool_defs_names = [t.tool_name for t in tool_defs] if expected_memory: - assert MEMORY_QUERY_TOOL in tool_defs + assert MEMORY_QUERY_TOOL in tool_defs_names if expected_code_interpreter: - assert BuiltinTool.code_interpreter in tool_defs + assert BuiltinTool.code_interpreter in tool_defs_names if expected_memory and expected_code_interpreter: # override the tools for turn new_tool_defs, _ = await chat_agent._get_tool_defs( @@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex ) ] ) - assert MEMORY_QUERY_TOOL in new_tool_defs - assert BuiltinTool.code_interpreter not in new_tool_defs + new_tool_defs_names = [t.tool_name for t in new_tool_defs] + assert MEMORY_QUERY_TOOL in new_tool_defs_names + assert BuiltinTool.code_interpreter not in new_tool_defs_names