Integrate distro docs into the restructured docs

Ashwin Bharambe 2024-11-20 23:20:05 -08:00
parent 2411a44833
commit cd6ccb664c
17 changed files with 306 additions and 115 deletions

@@ -6,16 +6,16 @@
 from pathlib import Path
 
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import ModelInput, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
-    MetaReferenceInferenceConfig,
+    MetaReferenceQuantizedInferenceConfig,
 )
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 
 
 def get_distribution_template() -> DistributionTemplate:
     providers = {
-        "inference": ["inline::meta-reference"],
+        "inference": ["inline::meta-reference-quantized"],
         "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
@@ -24,8 +24,8 @@ def get_distribution_template() -> DistributionTemplate:
     inference_provider = Provider(
         provider_id="meta-reference-inference",
-        provider_type="inline::meta-reference",
-        config=MetaReferenceInferenceConfig.sample_run_config(
+        provider_type="inline::meta-reference-quantized",
+        config=MetaReferenceQuantizedInferenceConfig.sample_run_config(
             model="${env.INFERENCE_MODEL}",
             checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
         ),
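
The sample config above uses llama-stack's "${env.VAR}" / "${env.VAR:default}" substitution convention: a value is read from the environment, with the text after the colon used as a fallback when the variable is unset. A hedged sketch of that convention (not the project's actual resolver):

import os
import re

# Matches "${env.NAME}" or "${env.NAME:default}".
_ENV_REF = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")

def resolve_env_refs(value: str) -> str:
    def substitute(m: re.Match) -> str:
        name, default = m.group(1), m.group(2)
        # Fall back to the declared default (or empty) when unset.
        return os.environ.get(name, default if default is not None else "")
    return _ENV_REF.sub(substitute, value)

# With INFERENCE_CHECKPOINT_DIR unset, the fallback applies:
print(resolve_env_refs("${env.INFERENCE_CHECKPOINT_DIR:null}"))  # -> null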
@@ -35,18 +35,13 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.INFERENCE_MODEL}",
         provider_id="meta-reference-inference",
     )
-    safety_model = ModelInput(
-        model_id="${env.SAFETY_MODEL}",
-        provider_id="meta-reference-safety",
-    )
 
     return DistributionTemplate(
-        name="meta-reference-gpu",
+        name="meta-reference-quantized-gpu",
         distro_type="self_hosted",
-        description="Use Meta Reference for running LLM inference",
+        description="Use Meta Reference with fp8, int4 quantization for running LLM inference",
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
-        default_models=[inference_model, safety_model],
+        default_models=[inference_model],
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
@@ -54,26 +49,6 @@ def get_distribution_template() -> DistributionTemplate:
                 },
                 default_models=[inference_model],
             ),
-            "run-with-safety.yaml": RunConfigSettings(
-                provider_overrides={
-                    "inference": [
-                        inference_provider,
-                        Provider(
-                            provider_id="meta-reference-safety",
-                            provider_type="inline::meta-reference",
-                            config=MetaReferenceInferenceConfig.sample_run_config(
-                                model="${env.SAFETY_MODEL}",
-                                checkpoint_dir="${env.SAFETY_CHECKPOINT_DIR:null}",
-                            ),
-                        ),
-                    ],
-                },
-                default_models=[
-                    inference_model,
-                    safety_model,
-                ],
-                default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
-            ),
         },
         run_config_env_vars={
             "LLAMASTACK_PORT": (
@@ -88,13 +63,5 @@ def get_distribution_template() -> DistributionTemplate:
             "null",
             "Directory containing the Meta Reference model checkpoint",
         ),
-        "SAFETY_MODEL": (
-            "meta-llama/Llama-Guard-3-1B",
-            "Name of the safety (Llama-Guard) model to use",
-        ),
-        "SAFETY_CHECKPOINT_DIR": (
-            "null",
-            "Directory containing the Llama-Guard model checkpoint",
-        ),
     },
 )
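
A minimal sketch (not part of this commit) for loading and inspecting the updated template; the file path is an assumption based on the template name, and the real CLI wires this up through its own build machinery:

import importlib.util
from pathlib import Path

# Assumed location of the file changed in this commit.
path = Path("llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py")
spec = importlib.util.spec_from_file_location("meta_reference", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

template = module.get_distribution_template()
assert template.name == "meta-reference-quantized-gpu"
# After this change the template ships a single default model and a
# single run.yaml (the run-with-safety.yaml variant is gone).
print(sorted(template.run_configs))  # -> ['run.yaml']
print([m.model_id for m in template.default_models])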