Mirror of https://github.com/meta-llama/llama-stack.git
remove hardcoded all-mini
commit 450bd60517
parent 0293b18f55
4 changed files with 13 additions and 1 deletion

llama_stack/providers/inline/agents/meta_reference/__init__.py

@@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Ap
         deps[Api.safety],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
+        deps[Api.models],
     )
     await impl.initialize()
     return impl
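
The hunk above threads the Models API into the provider factory. Below is a minimal runnable sketch of that factory pattern; the config handling, the stub implementation, and all other dependencies are trimmed as assumptions, so only the newly added deps[Api.models] path is shown.

import asyncio
from enum import Enum


class Api(Enum):
    # Trimmed to the one member this hunk adds; the real enum is larger.
    models = "models"


class MetaReferenceAgentsImpl:
    # Stub standing in for the real implementation: it only records the
    # injected Models API handle.
    def __init__(self, config, models_api) -> None:
        self.config = config
        self.models_api = models_api

    async def initialize(self) -> None:
        pass


async def get_provider_impl(config, deps):
    # Mirrors the hunk: pull the newly declared dependency out of the
    # deps mapping and pass it to the implementation.
    impl = MetaReferenceAgentsImpl(config, deps[Api.models])
    await impl.initialize()
    return impl


asyncio.run(get_provider_impl(config=None, deps={Api.models: object()}))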

llama_stack/providers/inline/agents/meta_reference/agent_instance.py

@@ -59,6 +59,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO

@@ -94,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
+        models_api: Models,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id

@@ -102,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.inference_api = inference_api
         self.safety_api = safety_api
         self.vector_io_api = vector_io_api
+        self.models_api = models_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api

@@ -825,9 +828,12 @@ class ChatAgent(ShieldRunnerMixin):
             # TODO: the semantic for registration is definitely not "creation"
             # so we need to fix it if we expect the agent to create a new vector db
             # for each session
+            list_models_response = await self.models_api.list_models()
+            embdding_models = [x for x in list_models_response.data if x.model_type == "embedding"]
             await self.vector_io_api.register_vector_db(
                 vector_db_id=vector_db_id,
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embdding_models[0].identifier,
+                embedding_dimension=embdding_models[0].metadata["embedding_dimension"],
             )
             await self.storage.add_vector_db_to_session(session_id, vector_db_id)
         else:
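
This hunk is the substance of the commit: instead of hardcoding all-MiniLM-L6-v2, the agent now asks the Models API for registered embedding models and takes the first one (the local variable really is spelled embdding_models in the merged code). A minimal sketch of that lookup in isolation follows; the helper name and the empty-list guard are assumptions, not part of the commit, whose [0] indexing would raise an IndexError when no embedding model is registered.

from typing import Tuple


async def pick_embedding_model(models_api) -> Tuple[str, int]:
    # Hypothetical helper isolating the lookup that the diff above inlines.
    response = await models_api.list_models()
    embedding_models = [m for m in response.data if m.model_type == "embedding"]
    if not embedding_models:
        # Assumed improvement: fail with a clear message instead of the
        # IndexError the committed [0] indexing would raise.
        raise RuntimeError("no embedding model registered with the Models API")
    first = embedding_models[0]
    return first.identifier, first.metadata["embedding_dimension"]

The returned pair feeds register_vector_db(embedding_model=..., embedding_dimension=...) exactly as the hunk does. Note that "first registered embedding model" makes the choice depend on registration order, a behavioral change from the previous fixed default.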

llama_stack/providers/inline/agents/meta_reference/agents.py

@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO

@@ -52,6 +53,7 @@ class MetaReferenceAgentsImpl(Agents):
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        models_api: Models,
     ):
         self.config = config
         self.inference_api = inference_api

@@ -59,6 +61,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.models_api = models_api

         self.in_memory_store = InmemoryKVStoreImpl()
         self.tempdir = tempfile.mkdtemp()

@@ -115,6 +118,7 @@ class MetaReferenceAgentsImpl(Agents):
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
+            models_api=self.models_api,
             persistence_store=(
                 self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
             ),
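
Taken together, the agents.py hunks are mechanical plumbing: MetaReferenceAgentsImpl accepts models_api in its constructor, stores it next to the other API handles, and forwards it to each ChatAgent it constructs, which is what makes self.models_api available at the vector-db registration site above.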

llama_stack/providers/registry/agents.py

@@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]:
                 Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
+                Api.models,
             ],
         ),
         remote_provider_spec(
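
The registry edit is the other half of the contract from the first hunk: a provider only receives deps[Api.models] if Api.models appears in its declared api_dependencies. A toy sketch of that resolution step follows, assuming a simple dict-based resolver; the real llama-stack resolver is considerably more involved.

from enum import Enum


class Api(Enum):
    # Trimmed to the members used in this sketch.
    safety = "safety"
    models = "models"


def resolve(api_dependencies, registry):
    # Toy resolver: hand the provider exactly the APIs it declared.
    # Without Api.models in api_dependencies, deps would lack the entry
    # and the deps[Api.models] access in the first hunk would fail.
    return {api: registry[api] for api in api_dependencies}


registry = {Api.safety: object(), Api.models: object()}
deps = resolve([Api.safety, Api.models], registry)
assert Api.models in deps  # now available as deps[Api.models]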