mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
remove hardcoded all-mini
This commit is contained in:
parent
0293b18f55
commit
450bd60517
4 changed files with 13 additions and 1 deletions
|
@ -21,6 +21,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Ap
|
||||||
deps[Api.safety],
|
deps[Api.safety],
|
||||||
deps[Api.tool_runtime],
|
deps[Api.tool_runtime],
|
||||||
deps[Api.tool_groups],
|
deps[Api.tool_groups],
|
||||||
|
deps[Api.models],
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -59,6 +59,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
@ -94,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_runtime_api: ToolRuntime,
|
tool_runtime_api: ToolRuntime,
|
||||||
tool_groups_api: ToolGroups,
|
tool_groups_api: ToolGroups,
|
||||||
vector_io_api: VectorIO,
|
vector_io_api: VectorIO,
|
||||||
|
models_api: Models,
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
):
|
):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
@ -102,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.vector_io_api = vector_io_api
|
self.vector_io_api = vector_io_api
|
||||||
|
self.models_api = models_api
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
|
@ -825,9 +828,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: the semantic for registration is definitely not "creation"
|
# TODO: the semantic for registration is definitely not "creation"
|
||||||
# so we need to fix it if we expect the agent to create a new vector db
|
# so we need to fix it if we expect the agent to create a new vector db
|
||||||
# for each session
|
# for each session
|
||||||
|
list_models_response = await self.models_api.list_models()
|
||||||
|
embdding_models = [x for x in list_models_response.data if x.model_type == "embedding"]
|
||||||
await self.vector_io_api.register_vector_db(
|
await self.vector_io_api.register_vector_db(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model=embdding_models[0].identifier,
|
||||||
|
embedding_dimension=embdding_models[0].metadata["embedding_dimension"],
|
||||||
)
|
)
|
||||||
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
|
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
@ -52,6 +53,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
tool_runtime_api: ToolRuntime,
|
tool_runtime_api: ToolRuntime,
|
||||||
tool_groups_api: ToolGroups,
|
tool_groups_api: ToolGroups,
|
||||||
|
models_api: Models,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
@ -59,6 +61,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
|
self.models_api = models_api
|
||||||
|
|
||||||
self.in_memory_store = InmemoryKVStoreImpl()
|
self.in_memory_store = InmemoryKVStoreImpl()
|
||||||
self.tempdir = tempfile.mkdtemp()
|
self.tempdir = tempfile.mkdtemp()
|
||||||
|
@ -115,6 +118,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
vector_io_api=self.vector_io_api,
|
vector_io_api=self.vector_io_api,
|
||||||
tool_runtime_api=self.tool_runtime_api,
|
tool_runtime_api=self.tool_runtime_api,
|
||||||
tool_groups_api=self.tool_groups_api,
|
tool_groups_api=self.tool_groups_api,
|
||||||
|
models_api=self.models_api,
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
|
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
|
||||||
),
|
),
|
||||||
|
|
|
@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.vector_dbs,
|
Api.vector_dbs,
|
||||||
Api.tool_runtime,
|
Api.tool_runtime,
|
||||||
Api.tool_groups,
|
Api.tool_groups,
|
||||||
|
Api.models,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue