diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index 17fc26632..78b5408a4 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -18,6 +18,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) +from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -41,6 +42,7 @@ def get_distribution_template() -> DistributionTemplate: ], } + name = "cerebras" inference_provider = Provider( provider_id="cerebras", provider_type="remote::cerebras", @@ -71,6 +73,11 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + memory_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + ) default_tool_groups = [ ToolGroupInput( toolgroup_id="builtin::websearch", @@ -98,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider, embedding_provider], + "memory": [memory_provider], }, default_models=default_models + [embedding_model], default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],