Remove hardcoded all-MiniLM-L6-v2 embedding model

This commit is contained in:
Xi Yan 2025-02-15 15:15:59 -08:00
parent 0293b18f55
commit 450bd60517
4 changed files with 13 additions and 1 deletion

View file

@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Ap
deps[Api.safety], deps[Api.safety],
deps[Api.tool_runtime], deps[Api.tool_runtime],
deps[Api.tool_groups], deps[Api.tool_groups],
deps[Api.models],
) )
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -59,6 +59,7 @@ from llama_stack.apis.inference import (
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
@ -94,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
vector_io_api: VectorIO, vector_io_api: VectorIO,
models_api: Models,
persistence_store: KVStore, persistence_store: KVStore,
): ):
self.agent_id = agent_id self.agent_id = agent_id
@ -102,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
self.inference_api = inference_api self.inference_api = inference_api
self.safety_api = safety_api self.safety_api = safety_api
self.vector_io_api = vector_io_api self.vector_io_api = vector_io_api
self.models_api = models_api
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
@ -825,9 +828,12 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: the semantic for registration is definitely not "creation" # TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db # so we need to fix it if we expect the agent to create a new vector db
# for each session # for each session
list_models_response = await self.models_api.list_models()
embdding_models = [x for x in list_models_response.data if x.model_type == "embedding"]
await self.vector_io_api.register_vector_db( await self.vector_io_api.register_vector_db(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model=embdding_models[0].identifier,
embedding_dimension=embdding_models[0].metadata["embedding_dimension"],
) )
await self.storage.add_vector_db_to_session(session_id, vector_db_id) await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else: else:

View file

@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
@ -52,6 +53,7 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety, safety_api: Safety,
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
models_api: Models,
): ):
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
@ -59,6 +61,7 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.models_api = models_api
self.in_memory_store = InmemoryKVStoreImpl() self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp() self.tempdir = tempfile.mkdtemp()
@ -115,6 +118,7 @@ class MetaReferenceAgentsImpl(Agents):
vector_io_api=self.vector_io_api, vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api, tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api, tool_groups_api=self.tool_groups_api,
models_api=self.models_api,
persistence_store=( persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
), ),

View file

@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_dbs, Api.vector_dbs,
Api.tool_runtime, Api.tool_runtime,
Api.tool_groups, Api.tool_groups,
Api.models,
], ],
), ),
remote_provider_spec( remote_provider_spec(