This commit is contained in:
Xi Yan 2025-03-18 21:49:11 -07:00
parent 011fd59a29
commit 8a576d7d72
24 changed files with 297 additions and 2525 deletions

View file

@ -11,8 +11,8 @@ from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOCon
from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)

View file

@ -16,8 +16,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.cerebras.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)

View file

@ -22,8 +22,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)

View file

@ -45,8 +45,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
)
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)
@ -96,10 +96,7 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": (
[p.provider_type for p in inference_providers]
+ ["inline::sentence-transformers"]
),
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}"
),
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",

View file

@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.fireworks.config import FireworksImp
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)

View file

@ -15,8 +15,8 @@ from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)

View file

@ -17,8 +17,8 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)
@ -87,9 +87,7 @@ def get_distribution_template() -> DistributionTemplate:
]
},
default_models=[inference_model, safety_model],
default_shields=[
ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
},

View file

@ -9,7 +9,6 @@ from typing import Dict, List, Tuple
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
DatasetInput,
ModelInput,
Provider,
@ -31,14 +30,12 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)
def get_inference_providers() -> (
Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]
):
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
@ -119,9 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}"
),
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",

View file

@ -21,8 +21,8 @@ from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES
from llama_stack.templates.template import (
DistributionTemplate,
get_model_registry,
RunConfigSettings,
get_model_registry,
)