mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00
Allow specifying resources in StackRunConfig
This commit is contained in:
parent
8035fa1869
commit
38257a9cbe
9 changed files with 151 additions and 102 deletions
|
@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type
|
||||||
|
|
||||||
schema_utils.json_schema_type = json_schema_type
|
schema_utils.json_schema_type = json_schema_type
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_stack.distribution.stack import LlamaStack
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
|
||||||
from llama_stack.apis.datasetio import * # noqa: F403
|
|
||||||
from llama_stack.apis.scoring import * # noqa: F403
|
|
||||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
|
||||||
from llama_stack.apis.eval import * # noqa: F403
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
|
||||||
from llama_stack.apis.batch_inference import * # noqa: F403
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
|
||||||
from llama_stack.apis.post_training import * # noqa: F403
|
|
||||||
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
|
||||||
from llama_stack.apis.models import * # noqa: F403
|
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
|
||||||
from llama_stack.apis.shields import * # noqa: F403
|
|
||||||
from llama_stack.apis.inspect import * # noqa: F403
|
|
||||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaStack(
|
|
||||||
MemoryBanks,
|
|
||||||
Inference,
|
|
||||||
BatchInference,
|
|
||||||
Agents,
|
|
||||||
Safety,
|
|
||||||
SyntheticDataGeneration,
|
|
||||||
Datasets,
|
|
||||||
Telemetry,
|
|
||||||
PostTraining,
|
|
||||||
Memory,
|
|
||||||
Eval,
|
|
||||||
EvalTasks,
|
|
||||||
Scoring,
|
|
||||||
ScoringFunctions,
|
|
||||||
DatasetIO,
|
|
||||||
Models,
|
|
||||||
Shields,
|
|
||||||
Inspect,
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
|
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
|
||||||
|
|
|
@ -22,6 +22,9 @@ class ResourceType(Enum):
|
||||||
class Resource(BaseModel):
|
class Resource(BaseModel):
|
||||||
"""Base class for all Llama Stack resources"""
|
"""Base class for all Llama Stack resources"""
|
||||||
|
|
||||||
|
# TODO: I think we need to move these into the child classes
|
||||||
|
# and make them `model_id`, `shield_id`, etc. because otherwise
|
||||||
|
# the config file has these confusing generic names in there
|
||||||
identifier: str = Field(
|
identifier: str = Field(
|
||||||
description="Unique identifier for this resource in llama stack"
|
description="Unique identifier for this resource in llama stack"
|
||||||
)
|
)
|
||||||
|
|
|
@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no
|
||||||
a default SQLite store will be used.""",
|
a default SQLite store will be used.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# registry of "resources" in the distribution
|
||||||
|
models: List[Model] = Field(default_factory=list)
|
||||||
|
shields: List[Shield] = Field(default_factory=list)
|
||||||
|
memory_banks: List[MemoryBank] = Field(default_factory=list)
|
||||||
|
datasets: List[Dataset] = Field(default_factory=list)
|
||||||
|
scoring_fns: List[ScoringFn] = Field(default_factory=list)
|
||||||
|
eval_tasks: List[EvalTask] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||||
|
|
|
@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
builtin_automatically_routed_apis,
|
|
||||||
get_provider_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
|
@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
|
from llama_stack.distribution.stack import construct_stack
|
||||||
|
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
|
||||||
def create_sse_event(data: Any) -> str:
|
def create_sse_event(data: Any) -> str:
|
||||||
if isinstance(data, BaseModel):
|
if isinstance(data, BaseModel):
|
||||||
data = data.json()
|
data = data.model_dump_json()
|
||||||
else:
|
else:
|
||||||
data = json.dumps(data)
|
data = json.dumps(data)
|
||||||
|
|
||||||
|
@ -281,12 +277,8 @@ def main(
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(
|
impls = asyncio.run(construct_stack(config))
|
||||||
resolve_impls(config, get_provider_registry(), dist_registry)
|
|
||||||
)
|
|
||||||
except InvalidProviderError:
|
except InvalidProviderError:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
79
llama_stack/distribution/stack.py
Normal file
79
llama_stack/distribution/stack.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
|
from llama_stack.apis.scoring import * # noqa: F403
|
||||||
|
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||||
|
from llama_stack.apis.eval import * # noqa: F403
|
||||||
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.apis.batch_inference import * # noqa: F403
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
from llama_stack.apis.post_training import * # noqa: F403
|
||||||
|
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
||||||
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
from llama_stack.apis.models import * # noqa: F403
|
||||||
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
|
from llama_stack.apis.shields import * # noqa: F403
|
||||||
|
from llama_stack.apis.inspect import * # noqa: F403
|
||||||
|
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaStack(
|
||||||
|
MemoryBanks,
|
||||||
|
Inference,
|
||||||
|
BatchInference,
|
||||||
|
Agents,
|
||||||
|
Safety,
|
||||||
|
SyntheticDataGeneration,
|
||||||
|
Datasets,
|
||||||
|
Telemetry,
|
||||||
|
PostTraining,
|
||||||
|
Memory,
|
||||||
|
Eval,
|
||||||
|
EvalTasks,
|
||||||
|
Scoring,
|
||||||
|
ScoringFunctions,
|
||||||
|
DatasetIO,
|
||||||
|
Models,
|
||||||
|
Shields,
|
||||||
|
Inspect,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
|
# asked for in the run config.
|
||||||
|
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
|
dist_registry, _ = await create_dist_registry(
|
||||||
|
run_config.metadata_store, run_config.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||||
|
|
||||||
|
objects = [
|
||||||
|
*run_config.models,
|
||||||
|
*run_config.shields,
|
||||||
|
*run_config.memory_banks,
|
||||||
|
*run_config.datasets,
|
||||||
|
*run_config.scoring_fns,
|
||||||
|
*run_config.eval_tasks,
|
||||||
|
]
|
||||||
|
for obj in objects:
|
||||||
|
await dist_registry.register(obj)
|
||||||
|
|
||||||
|
return impls
|
|
@ -5,14 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, List, Protocol
|
from typing import Dict, List, Optional, Protocol
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||||
RoutableObjectWithProvider,
|
|
||||||
StackRunConfig,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import (
|
from llama_stack.providers.utils.kvstore import (
|
||||||
|
@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||||
|
|
||||||
|
|
||||||
async def create_dist_registry(
|
async def create_dist_registry(
|
||||||
config: StackRunConfig,
|
metadata_store: Optional[KVStoreConfig],
|
||||||
|
image_name: str,
|
||||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||||
# instantiate kvstore for storing and retrieving distribution metadata
|
# instantiate kvstore for storing and retrieving distribution metadata
|
||||||
if config.metadata_store:
|
if metadata_store:
|
||||||
dist_kvstore = await kvstore_impl(config.metadata_store)
|
dist_kvstore = await kvstore_impl(metadata_store)
|
||||||
else:
|
else:
|
||||||
dist_kvstore = await kvstore_impl(
|
dist_kvstore = await kvstore_impl(
|
||||||
SqliteKVStoreConfig(
|
SqliteKVStoreConfig(
|
||||||
db_path=(
|
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
|
||||||
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
|
||||||
).as_posix()
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,8 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
from llama_stack.providers.inline.inference.meta_reference import (
|
from llama_stack.providers.inline.inference.meta_reference import (
|
||||||
MetaReferenceInferenceConfig,
|
MetaReferenceInferenceConfig,
|
||||||
|
@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
|
||||||
[Api.inference],
|
[Api.inference],
|
||||||
{"inference": inference_fixture.providers},
|
{"inference": inference_fixture.providers},
|
||||||
inference_fixture.provider_data,
|
inference_fixture.provider_data,
|
||||||
)
|
models=[
|
||||||
|
Model(
|
||||||
provider_id = inference_fixture.providers[0].provider_id
|
identifier=inference_model,
|
||||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
provider_resource_id=inference_model,
|
||||||
await impls[Api.models].register_model(
|
provider_id=inference_fixture.providers[0].provider_id,
|
||||||
model_id=inference_model,
|
)
|
||||||
provider_id=provider_id,
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return (impls[Api.inference], impls[Api.models])
|
return (impls[Api.inference], impls[Api.models])
|
||||||
|
|
|
@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.stack import construct_stack
|
||||||
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test_v2(
|
async def resolve_impls_for_test_v2(
|
||||||
apis: List[Api],
|
apis: List[Api],
|
||||||
providers: Dict[str, List[Provider]],
|
providers: Dict[str, List[Provider]],
|
||||||
provider_data: Optional[Dict[str, Any]] = None,
|
provider_data: Optional[Dict[str, Any]] = None,
|
||||||
|
models: Optional[List[Model]] = None,
|
||||||
|
shields: Optional[List[Shield]] = None,
|
||||||
|
memory_banks: Optional[List[MemoryBank]] = None,
|
||||||
|
datasets: Optional[List[Dataset]] = None,
|
||||||
|
scoring_fns: Optional[List[ScoringFn]] = None,
|
||||||
|
eval_tasks: Optional[List[EvalTask]] = None,
|
||||||
):
|
):
|
||||||
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
run_config = dict(
|
run_config = dict(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
image_name="test-fixture",
|
image_name="test-fixture",
|
||||||
apis=apis,
|
apis=apis,
|
||||||
providers=providers,
|
providers=providers,
|
||||||
|
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
|
||||||
|
models=models or [],
|
||||||
|
shields=shields or [],
|
||||||
|
memory_banks=memory_banks or [],
|
||||||
|
datasets=datasets or [],
|
||||||
|
scoring_fns=scoring_fns or [],
|
||||||
|
eval_tasks=eval_tasks or [],
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
|
|
||||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
|
||||||
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
|
|
||||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
|
||||||
try:
|
try:
|
||||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
impls = await construct_stack(run_config)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print_pip_install_help(providers)
|
print_pip_install_help(providers)
|
||||||
raise e
|
raise e
|
||||||
|
|
|
@ -7,7 +7,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.apis.shields import ShieldType
|
from llama_stack.apis.models import Model
|
||||||
|
|
||||||
|
from llama_stack.apis.shields import Shield, ShieldType
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||||
|
@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
|
||||||
if safety_fixture.provider_data:
|
if safety_fixture.provider_data:
|
||||||
provider_data.update(safety_fixture.provider_data)
|
provider_data.update(safety_fixture.provider_data)
|
||||||
|
|
||||||
|
shield_provider_type = safety_fixture.providers[0].provider_type
|
||||||
|
shield = get_shield(
|
||||||
|
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
|
||||||
|
)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
impls = await resolve_impls_for_test_v2(
|
||||||
[Api.safety, Api.shields, Api.inference],
|
[Api.safety, Api.shields, Api.inference],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
models=[
|
||||||
|
Model(
|
||||||
|
identifier=inference_model,
|
||||||
|
provider_id=inference_fixture.providers[0].provider_id,
|
||||||
|
provider_resource_id=inference_model,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
shields=[shield],
|
||||||
)
|
)
|
||||||
|
|
||||||
safety_impl = impls[Api.safety]
|
return impls[Api.safety], impls[Api.shields], shield
|
||||||
shields_impl = impls[Api.shields]
|
|
||||||
|
|
||||||
# Register the appropriate shield based on provider type
|
|
||||||
provider_type = safety_fixture.providers[0].provider_type
|
|
||||||
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
|
|
||||||
|
|
||||||
provider_id = inference_fixture.providers[0].provider_id
|
|
||||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
|
||||||
await impls[Api.models].register_model(
|
|
||||||
model_id=inference_model,
|
|
||||||
provider_id=provider_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return safety_impl, shields_impl, shield
|
|
||||||
|
|
||||||
|
|
||||||
async def create_and_register_shield(
|
def get_shield(provider_type: str, provider_id: str, safety_model: str):
|
||||||
provider_type: str, safety_model: str, shields_impl
|
|
||||||
):
|
|
||||||
shield_config = {}
|
shield_config = {}
|
||||||
shield_type = ShieldType.llama_guard
|
shield_type = ShieldType.llama_guard
|
||||||
identifier = "llama_guard"
|
identifier = "llama_guard"
|
||||||
|
@ -134,8 +133,10 @@ async def create_and_register_shield(
|
||||||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||||
shield_type = ShieldType.generic_content_shield
|
shield_type = ShieldType.generic_content_shield
|
||||||
|
|
||||||
return await shields_impl.register_shield(
|
return Shield(
|
||||||
shield_id=identifier,
|
identifier=identifier,
|
||||||
shield_type=shield_type,
|
shield_type=shield_type,
|
||||||
params=shield_config,
|
params=shield_config,
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_resource_id=identifier,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue