mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 09:58:10 +00:00
Auto-generate distro yamls + docs (#468)
# What does this PR do? Automatically generates - build.yaml - run.yaml - run-with-safety.yaml - parts of markdown docs for the distributions. ## Test Plan At this point, this only updates the YAMLs and the docs. Some testing (especially with ollama and vllm) has been performed but needs to be much more tested.
This commit is contained in:
parent
0784284ab5
commit
2a31163178
88 changed files with 3008 additions and 852 deletions
|
@ -4,11 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig = Field(default=SqliteKVStoreConfig())
|
||||
persistence_store: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="agents_store.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -49,6 +49,18 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
resolved = resolve_model(self.model)
|
||||
return resolved.pth_file_count
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"max_seq_len": 4096,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
}
|
||||
|
||||
|
||||
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
|
||||
quantization: QuantizationConfig
|
||||
|
|
|
@ -107,7 +107,7 @@ class Llama:
|
|||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
if config.checkpoint_dir:
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
@ -137,7 +137,6 @@ class Llama:
|
|||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
|
|
|
@ -34,6 +34,16 @@ class VLLMConfig(BaseModel):
|
|||
default=0.3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"model": "${env.VLLM_INFERENCE_MODEL:Llama3.2-3B-Instruct}",
|
||||
"tensor_parallel_size": "${env.VLLM_TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"enforce_eager": "${env.VLLM_ENFORCE_EAGER:False}",
|
||||
"gpu_memory_utilization": "${env.VLLM_GPU_MEMORY_UTILIZATION:0.3}",
|
||||
}
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
|
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -16,6 +17,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
@json_schema_type
|
||||
class FaissImplConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "faiss_store.db").as_posix()
|
||||
) # Uses SQLite config specific to FAISS storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="faiss_store.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -73,18 +73,21 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
|||
CAT_ELECTIONS,
|
||||
]
|
||||
|
||||
LLAMA_GUARD_MODEL_IDS = [
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
]
|
||||
# accept both CoreModelId and huggingface repo id
|
||||
LLAMA_GUARD_MODEL_IDS = {
|
||||
CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
|
||||
"meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
|
||||
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
}
|
||||
|
||||
MODEL_TO_SAFETY_CATEGORIES_MAP = {
|
||||
CoreModelId.llama_guard_3_8b.value: (
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
),
|
||||
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
|
||||
+ [CAT_CODE_INTERPRETER_ABUSE],
|
||||
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
|
||||
}
|
||||
|
||||
|
||||
|
@ -150,8 +153,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
|
||||
impl = LlamaGuardShield(
|
||||
model=shield.provider_resource_id,
|
||||
model=model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -20,3 +20,10 @@ class FireworksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference",
|
||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"fireworks/llama-v3p1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
@ -79,7 +79,7 @@ class FireworksInferenceAdapter(
|
|||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_aliases)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
|
|
|
@ -4,14 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from .config import OllamaImplConfig
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteProviderConfig):
|
||||
port: int = 11434
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
|
|
22
llama_stack/providers/remote/inference/ollama/config.py
Normal file
22
llama_stack/providers/remote/inference/ollama/config.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
return {"url": url}
|
|
@ -82,7 +82,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print("Initializing Ollama, checking connectivity to server...")
|
||||
print(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
|
|
@ -12,19 +12,20 @@ from pydantic import BaseModel, Field
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A bearer token if your TGI endpoint is protected.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -20,3 +20,10 @@ class TogetherImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
@ -78,7 +78,7 @@ class TogetherInferenceAdapter(
|
|||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_aliases)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
|
|
|
@ -24,3 +24,15 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.VLLM_URL}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:fake}",
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import json
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
@ -37,7 +36,6 @@ async def construct_stack_for_test(
|
|||
) -> TestStack:
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
|
|
|
@ -36,6 +36,15 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"type": "redis",
|
||||
"namespace": None,
|
||||
"host": "${env.REDIS_HOST:localhost}",
|
||||
"port": "${env.REDIS_PORT:6379}",
|
||||
}
|
||||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
|
@ -44,6 +53,19 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
|
||||
):
|
||||
return {
|
||||
"type": "sqlite",
|
||||
"namespace": None,
|
||||
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
|
||||
+ __distro_dir__
|
||||
+ "}/"
|
||||
+ db_name,
|
||||
}
|
||||
|
||||
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
|
@ -54,6 +76,19 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
password: Optional[str] = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, table_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
"type": "postgres",
|
||||
"namespace": None,
|
||||
"host": "${env.POSTGRES_HOST:localhost}",
|
||||
"port": "${env.POSTGRES_PORT:5432}",
|
||||
"db": "${env.POSTGRES_DB}",
|
||||
"user": "${env.POSTGRES_USER}",
|
||||
"password": "${env.POSTGRES_PASSWORD}",
|
||||
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@field_validator("table_name")
|
||||
def validate_table_name(cls, v: str) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue