llama-stack/llama_stack/templates/bedrock/bedrock.py
2025-03-23 15:48:14 -07:00

85 lines
2.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
)
def get_distribution_template() -> DistributionTemplate:
    """Assemble the distribution template for the ``bedrock`` distribution.

    The template lists the provider types offered for each API, exposes the
    Bedrock model entries, and defines a single ``run.yaml`` run configuration
    in which ``vector_io`` is overridden with an inline FAISS provider.

    Returns:
        The populated ``DistributionTemplate``.
    """
    name = "bedrock"

    # Provider types offered per API, in the order they are declared here.
    tool_runtime_types = [
        "remote::brave-search",
        "remote::tavily-search",
        "inline::code-interpreter",
        "inline::rag-runtime",
        "remote::model-context-protocol",
    ]
    providers = {
        "inference": ["remote::bedrock"],
        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["remote::bedrock"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "tool_runtime": tool_runtime_types,
    }

    # Of the vector_io options above, only FAISS is wired into run.yaml.
    faiss_config = FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}")
    faiss_provider = Provider(
        provider_id="faiss",
        provider_type="inline::faiss",
        config=faiss_config,
    )

    available_models = {"bedrock": MODEL_ENTRIES}
    registry_models = get_model_registry(available_models)

    # (toolgroup_id, provider_id) pairs registered by default.
    tool_group_bindings = (
        ("builtin::websearch", "tavily-search"),
        ("builtin::rag", "rag-runtime"),
        ("builtin::code_interpreter", "code-interpreter"),
    )
    tool_groups = [
        ToolGroupInput(toolgroup_id=group_id, provider_id=provider_id)
        for group_id, provider_id in tool_group_bindings
    ]

    # The documentation template sits next to this module.
    doc_template = Path(__file__).parent / "doc_template.md"

    run_settings = RunConfigSettings(
        provider_overrides={"vector_io": [faiss_provider]},
        default_models=registry_models,
        default_tool_groups=tool_groups,
    )
    env_vars = {
        "LLAMA_STACK_PORT": (
            "8321",
            "Port for the Llama Stack distribution server",
        ),
    }

    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Use AWS Bedrock for running LLM inference and safety",
        container_image=None,
        template_path=doc_template,
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={"run.yaml": run_settings},
        run_config_env_vars=env_vars,
    )