From 450bd60517f7a1fbf9604be6cc2d848a60b5c252 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Feb 2025 15:15:59 -0800
Subject: [PATCH] remove hardcoded all-mini

---
 .../providers/inline/agents/meta_reference/__init__.py  | 1 +
 .../inline/agents/meta_reference/agent_instance.py      | 8 +++++++-
 .../providers/inline/agents/meta_reference/agents.py    | 4 ++++
 llama_stack/providers/registry/agents.py                | 1 +
 4 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 8f8c24170..5d3835ce9 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Ap
         deps[Api.safety],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
+        deps[Api.models],
     )
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index fc597d0f7..730b0f1f5 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -59,6 +59,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
@@ -94,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
+        models_api: Models,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
@@ -102,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.inference_api = inference_api
         self.safety_api = safety_api
         self.vector_io_api = vector_io_api
+        self.models_api = models_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -825,9 +828,12 @@ class ChatAgent(ShieldRunnerMixin):
             # TODO: the semantic for registration is definitely not "creation"
             # so we need to fix it if we expect the agent to create a new vector db
             # for each session
+            list_models_response = await self.models_api.list_models()
+            embedding_models = [x for x in list_models_response.data if x.model_type == "embedding"]
             await self.vector_io_api.register_vector_db(
                 vector_db_id=vector_db_id,
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embedding_models[0].identifier,
+                embedding_dimension=embedding_models[0].metadata["embedding_dimension"],
             )
             await self.storage.add_vector_db_to_session(session_id, vector_db_id)
         else:
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index e3c18d112..69418fe95 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
@@ -52,6 +53,7 @@ class MetaReferenceAgentsImpl(Agents):
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        models_api: Models,
     ):
         self.config = config
         self.inference_api = inference_api
@@ -59,6 +61,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.models_api = models_api
 
         self.in_memory_store = InmemoryKVStoreImpl()
         self.tempdir = tempfile.mkdtemp()
@@ -115,6 +118,7 @@ class MetaReferenceAgentsImpl(Agents):
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
+            models_api=self.models_api,
             persistence_store=(
                 self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
             ),
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 655303f98..3d2f2d584 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]:
                 Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
+                Api.models,
             ],
         ),
         remote_provider_spec(