diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md
index 03b71e057..acbc07ca4 100644
--- a/docs/source/building_applications/rag.md
+++ b/docs/source/building_applications/rag.md
@@ -122,7 +122,7 @@ response = agent.create_turn(
     ],
     documents=[
         {
-            "content": "https://raw.githubusercontent.com/example/doc.rst",
+            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
             "mime_type": "text/plain",
         }
     ],
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 73f9c9672..1be43ec8b 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if provider_vector_db_id is None:
             provider_vector_db_id = vector_db_id
         if provider_id is None:
-            # If provider_id not specified, use the only provider if it supports this shield type
-            if len(self.impls_by_provider_id) == 1:
+            if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
             else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
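
The routing-table hunk above replaces the single-provider check with a pick-first-and-warn fallback. A minimal standalone sketch of that selection logic, assuming a dict of provider impls and a module-level logger (the helper name `resolve_provider_id` is illustrative, not from the patch):

import logging

logger = logging.getLogger(__name__)


def resolve_provider_id(provider_id, impls_by_provider_id):
    # An explicit choice always wins.
    if provider_id is not None:
        return provider_id
    # No providers configured at all: fail loudly.
    if not impls_by_provider_id:
        raise ValueError("No provider available. Please configure a vector_io provider.")
    # Otherwise take the first registered provider, warning when the pick is ambiguous.
    provider_id = next(iter(impls_by_provider_id))
    if len(impls_by_provider_id) > 1:
        logger.warning(
            f"No provider specified and multiple providers available. "
            f"Arbitrarily selected the first provider {provider_id}."
        )
    return provider_id
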
diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
index b802937b6..84ab364b7 100644
--- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
+++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
@@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     StepType,
 )
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
     LogProbConfig,
@@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
     UserMessage,
 )
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
     Tool,
     ToolDef,
     ToolGroup,
@@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
 class MockInferenceAPI:
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="start",
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta=TextDelta(text=""),
                 )
             )
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="progress",
-                    delta="AI is a fascinating field...",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text="AI is a fascinating field..."),
                 )
             )
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="complete",
-                    delta="",
-                    stop_reason="end_of_turn",
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=""),
+                    stop_reason=StopReason.end_of_turn,
                 )
             )
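
With the hunk above, each streamed chunk carries an enum event type and a TextDelta instead of bare strings. A small sanity check of the new chunk shape, using the same llama_stack imports the test now pulls in (illustrative only):

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)

chunk = ChatCompletionResponseStreamChunk(
    event=ChatCompletionResponseEvent(
        event_type=ChatCompletionResponseEventType.progress,
        delta=TextDelta(text="AI is a fascinating field..."),
    )
)
# The payload is reached through typed attributes, not string comparison.
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta.text == "AI is a fascinating field..."
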
@@ -133,35 +138,39 @@ class MockToolGroupsAPI:
             provider_resource_id=toolgroup_id,
         )
 
-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return []
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=[])
 
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
-        if tool_group_id == MEMORY_TOOLGROUP:
-            return [
-                Tool(
-                    identifier=MEMORY_QUERY_TOOL,
-                    provider_resource_id=MEMORY_QUERY_TOOL,
-                    toolgroup_id=MEMORY_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::rag",
-                    parameters=[],
-                )
-            ]
-        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
-            return [
-                Tool(
-                    identifier="code_interpreter",
-                    provider_resource_id="code_interpreter",
-                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::code_interpreter",
-                    parameters=[],
-                )
-            ]
-        return []
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+        if toolgroup_id == MEMORY_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier=MEMORY_QUERY_TOOL,
+                        provider_resource_id=MEMORY_QUERY_TOOL,
+                        toolgroup_id=MEMORY_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::rag",
+                        parameters=[],
+                    )
+                ]
+            )
+        if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier="code_interpreter",
+                        provider_resource_id="code_interpreter",
+                        toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::code_interpreter",
+                        parameters=[],
+                    )
+                ]
+            )
+        return ListToolsResponse(data=[])
 
     async def get_tool(self, tool_name: str) -> Tool:
         return Tool(
@@ -174,7 +183,7 @@ class MockToolGroupsAPI:
             parameters=[],
         )
 
-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    async def unregister_tool_group(self, toolgroup_id: str) -> None:
         pass
 
 
@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
     chat_agent = await impl.get_agent(response.agent_id)
 
     tool_defs, _ = await chat_agent._get_tool_defs()
+    tool_defs_names = [t.tool_name for t in tool_defs]
     if expected_memory:
-        assert MEMORY_QUERY_TOOL in tool_defs
+        assert MEMORY_QUERY_TOOL in tool_defs_names
     if expected_code_interpreter:
-        assert BuiltinTool.code_interpreter in tool_defs
+        assert BuiltinTool.code_interpreter in tool_defs_names
     if expected_memory and expected_code_interpreter:
         # override the tools for turn
         new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
                 )
             ]
         )
-        assert MEMORY_QUERY_TOOL in new_tool_defs
-        assert BuiltinTool.code_interpreter not in new_tool_defs
+        new_tool_defs_names = [t.tool_name for t in new_tool_defs]
+        assert MEMORY_QUERY_TOOL in new_tool_defs_names
+        assert BuiltinTool.code_interpreter not in new_tool_defs_names
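
The assertion rewrite in the last two hunks matters because _get_tool_defs() returns ToolDefinition objects: `MEMORY_QUERY_TOOL in tool_defs` compares a string against objects and can never succeed. A minimal illustration of the failure mode and the fix (assumes ToolDefinition's other fields are optional; the tool name here is made up):

from llama_stack.apis.inference import ToolDefinition

tool_defs = [ToolDefinition(tool_name="example_tool")]

# Membership against the object list silently never matches a plain string.
assert "example_tool" not in tool_defs
# Projecting out tool_name first makes the membership test meaningful.
assert "example_tool" in [t.tool_name for t in tool_defs]
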