Allow specifying resources in StackRunConfig
https://github.com/meta-llama/llama-stack, commit 38257a9cbe (parent 8035fa1869)

9 changed files with 151 additions and 102 deletions
@@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type
 schema_utils.json_schema_type = json_schema_type
 
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.apis.datasets import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.apis.scoring import *  # noqa: F403
-from llama_stack.apis.scoring_functions import *  # noqa: F403
-from llama_stack.apis.eval import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.batch_inference import *  # noqa: F403
-from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.apis.telemetry import *  # noqa: F403
-from llama_stack.apis.post_training import *  # noqa: F403
-from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
-from llama_stack.apis.safety import *  # noqa: F403
-from llama_stack.apis.models import *  # noqa: F403
-from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.apis.shields import *  # noqa: F403
-from llama_stack.apis.inspect import *  # noqa: F403
-from llama_stack.apis.eval_tasks import *  # noqa: F403
-
-
-class LlamaStack(
-    MemoryBanks,
-    Inference,
-    BatchInference,
-    Agents,
-    Safety,
-    SyntheticDataGeneration,
-    Datasets,
-    Telemetry,
-    PostTraining,
-    Memory,
-    Eval,
-    EvalTasks,
-    Scoring,
-    ScoringFunctions,
-    DatasetIO,
-    Models,
-    Shields,
-    Inspect,
-):
-    pass
+from llama_stack.distribution.stack import LlamaStack
 
 
 # TODO: this should be fixed in the generator itself so it reads appropriate annotations
@@ -22,6 +22,9 @@ class ResourceType(Enum):
 class Resource(BaseModel):
     """Base class for all Llama Stack resources"""
 
+    # TODO: I think we need to move these into the child classes
+    # and make them `model_id`, `shield_id`, etc. because otherwise
+    # the config file has these confusing generic names in there
     identifier: str = Field(
         description="Unique identifier for this resource in llama stack"
     )
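For orientation (not part of this diff): concrete resources subclass this base and add provider-side routing fields. A minimal sketch; the field names `provider_id` and `provider_resource_id` are taken from the `Model(...)` constructions in the test fixtures below, everything else here is illustrative:

```python
from pydantic import BaseModel, Field


class Resource(BaseModel):
    """Base class for all Llama Stack resources"""

    identifier: str = Field(
        description="Unique identifier for this resource in llama stack"
    )


# Illustrative child class; the real Model lives in llama_stack.apis.models.
class Model(Resource):
    provider_id: str  # which provider serves this resource
    provider_resource_id: str  # the provider's own identifier for it
```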
@@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If not specified,
 a default SQLite store will be used.""",
     )
 
+    # registry of "resources" in the distribution
+    models: List[Model] = Field(default_factory=list)
+    shields: List[Shield] = Field(default_factory=list)
+    memory_banks: List[MemoryBank] = Field(default_factory=list)
+    datasets: List[Dataset] = Field(default_factory=list)
+    scoring_fns: List[ScoringFn] = Field(default_factory=list)
+    eval_tasks: List[EvalTask] = Field(default_factory=list)
+
 
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
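With these fields in place, a run config can pre-declare the resources it wants registered at startup. A minimal sketch, assuming the `Model` constructor arguments shown in the test fixtures later in this diff; the model and provider names are placeholders:

```python
from datetime import datetime

from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import StackRunConfig

run_config = StackRunConfig(
    built_at=datetime.now(),
    image_name="example",
    apis=["inference"],
    providers={},  # provider wiring elided for brevity
    models=[
        Model(
            identifier="Llama3.1-8B-Instruct",  # name clients request
            provider_resource_id="Llama3.1-8B-Instruct",  # provider-side id
            provider_id="example-provider",  # placeholder provider
        )
    ],
)
```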
@@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated
 
-from llama_stack.distribution.distribution import (
-    builtin_automatically_routed_apis,
-    get_provider_registry,
-)
-
-from llama_stack.distribution.store.registry import create_dist_registry
+from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
+from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.stack import construct_stack
 
 from .endpoints import get_all_api_endpoints
 
 
 def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
-        data = data.json()
+        data = data.model_dump_json()
     else:
         data = json.dumps(data)
@@ -281,12 +277,8 @@ def main(
 
     app = FastAPI()
 
-    dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
-
     try:
-        impls = asyncio.run(
-            resolve_impls(config, get_provider_registry(), dist_registry)
-        )
+        impls = asyncio.run(construct_stack(config))
     except InvalidProviderError:
         sys.exit(1)
llama_stack/distribution/stack.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.eval import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.batch_inference import *  # noqa: F403
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.telemetry import *  # noqa: F403
+from llama_stack.apis.post_training import *  # noqa: F403
+from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.inspect import *  # noqa: F403
+from llama_stack.apis.eval_tasks import *  # noqa: F403
+
+from llama_stack.distribution.datatypes import StackRunConfig
+from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.store.registry import create_dist_registry
+from llama_stack.providers.datatypes import Api
+
+
+class LlamaStack(
+    MemoryBanks,
+    Inference,
+    BatchInference,
+    Agents,
+    Safety,
+    SyntheticDataGeneration,
+    Datasets,
+    Telemetry,
+    PostTraining,
+    Memory,
+    Eval,
+    EvalTasks,
+    Scoring,
+    ScoringFunctions,
+    DatasetIO,
+    Models,
+    Shields,
+    Inspect,
+):
+    pass
+
+
+# Produces a stack of providers for the given run config. Not all APIs may be
+# asked for in the run config.
+async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+
+    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
+
+    objects = [
+        *run_config.models,
+        *run_config.shields,
+        *run_config.memory_banks,
+        *run_config.datasets,
+        *run_config.scoring_fns,
+        *run_config.eval_tasks,
+    ]
+    for obj in objects:
+        await dist_registry.register(obj)
+
+    return impls
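A minimal sketch of driving the new entry point; loading from a `run.yaml` path and passing the parsed dict through `parse_and_maybe_upgrade_config` are illustrative assumptions, not part of this diff:

```python
import asyncio

import yaml

from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.datatypes import Api

with open("run.yaml") as f:  # placeholder path to a run config
    run_config = parse_and_maybe_upgrade_config(yaml.safe_load(f))

# Resolves providers, then registers any models/shields/memory banks/etc.
# declared in the config before returning the implementations.
impls = asyncio.run(construct_stack(run_config))
inference = impls[Api.inference]
```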
@@ -5,14 +5,11 @@
 # the root directory of this source tree.
 
 import json
-from typing import Dict, List, Protocol
+from typing import Dict, List, Optional, Protocol
 
 import pydantic
 
-from llama_stack.distribution.datatypes import (
-    RoutableObjectWithProvider,
-    StackRunConfig,
-)
+from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 
 from llama_stack.providers.utils.kvstore import (
@@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
 
 
 async def create_dist_registry(
-    config: StackRunConfig,
+    metadata_store: Optional[KVStoreConfig],
+    image_name: str,
 ) -> tuple[CachedDiskDistributionRegistry, KVStore]:
     # instantiate kvstore for storing and retrieving distribution metadata
-    if config.metadata_store:
-        dist_kvstore = await kvstore_impl(config.metadata_store)
+    if metadata_store:
+        dist_kvstore = await kvstore_impl(metadata_store)
     else:
         dist_kvstore = await kvstore_impl(
             SqliteKVStoreConfig(
-                db_path=(
-                    DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
-                ).as_posix()
+                db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
             )
         )
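The new signature decouples the registry from `StackRunConfig`; callers pass only the two fields it actually uses, exactly as `construct_stack` above does. A small sketch (the wrapper function is illustrative):

```python
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.store.registry import create_dist_registry


async def make_registry(run_config: StackRunConfig):
    # The registry no longer needs the whole run config, just these two fields.
    dist_registry, dist_kvstore = await create_dist_registry(
        run_config.metadata_store, run_config.image_name
    )
    return dist_registry, dist_kvstore
```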
@@ -9,6 +9,8 @@ import os
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.models import Model
+
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
@@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-    )
-
-    provider_id = inference_fixture.providers[0].provider_id
-    print(f"Registering model {inference_model} with provider {provider_id}")
-    await impls[Api.models].register_model(
-        model_id=inference_model,
-        provider_id=provider_id,
+        models=[
+            Model(
+                identifier=inference_model,
+                provider_resource_id=inference_model,
+                provider_id=inference_fixture.providers[0].provider_id,
+            )
+        ],
     )
 
     return (impls[Api.inference], impls[Api.models])
@@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_impls
-from llama_stack.distribution.store import CachedDiskDistributionRegistry
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.distribution.stack import construct_stack
+from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
 
 async def resolve_impls_for_test_v2(
     apis: List[Api],
     providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
+    models: Optional[List[Model]] = None,
+    shields: Optional[List[Shield]] = None,
+    memory_banks: Optional[List[MemoryBank]] = None,
+    datasets: Optional[List[Dataset]] = None,
+    scoring_fns: Optional[List[ScoringFn]] = None,
+    eval_tasks: Optional[List[EvalTask]] = None,
 ):
+    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
         built_at=datetime.now(),
         image_name="test-fixture",
         apis=apis,
         providers=providers,
+        metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
+        models=models or [],
+        shields=shields or [],
+        memory_banks=memory_banks or [],
+        datasets=datasets or [],
+        scoring_fns=scoring_fns or [],
+        eval_tasks=eval_tasks or [],
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
 
-    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
-    dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
-    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
     try:
-        impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
+        impls = await construct_stack(run_config)
     except ModuleNotFoundError as e:
         print_pip_install_help(providers)
        raise e
@@ -7,7 +7,9 @@
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.shields import ShieldType
+from llama_stack.apis.models import Model
+
+from llama_stack.apis.shields import Shield, ShieldType
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
     if safety_fixture.provider_data:
         provider_data.update(safety_fixture.provider_data)
 
+    shield_provider_type = safety_fixture.providers[0].provider_type
+    shield = get_shield(
+        shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
+    )
+
     impls = await resolve_impls_for_test_v2(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
+        models=[
+            Model(
+                identifier=inference_model,
+                provider_id=inference_fixture.providers[0].provider_id,
+                provider_resource_id=inference_model,
+            )
+        ],
+        shields=[shield],
     )
 
-    safety_impl = impls[Api.safety]
-    shields_impl = impls[Api.shields]
-
-    # Register the appropriate shield based on provider type
-    provider_type = safety_fixture.providers[0].provider_type
-    shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
-
-    provider_id = inference_fixture.providers[0].provider_id
-    print(f"Registering model {inference_model} with provider {provider_id}")
-    await impls[Api.models].register_model(
-        model_id=inference_model,
-        provider_id=provider_id,
-    )
-
-    return safety_impl, shields_impl, shield
+    return impls[Api.safety], impls[Api.shields], shield
 
 
-async def create_and_register_shield(
-    provider_type: str, safety_model: str, shields_impl
-):
+def get_shield(provider_type: str, provider_id: str, safety_model: str):
     shield_config = {}
     shield_type = ShieldType.llama_guard
     identifier = "llama_guard"
@@ -134,8 +133,10 @@ async def create_and_register_shield(
         shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
         shield_type = ShieldType.generic_content_shield
 
-    return await shields_impl.register_shield(
-        shield_id=identifier,
+    return Shield(
+        identifier=identifier,
         shield_type=shield_type,
         params=shield_config,
+        provider_id=provider_id,
+        provider_resource_id=identifier,
     )
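Net effect across the fixtures: resources are built up front and declared to the stack, instead of being registered against running implementations afterwards. A condensed, illustrative sketch of the new flow; the import path of the test helper and all identifiers are assumptions:

```python
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api

# Import path assumed from the tests touched in this commit.
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2


async def make_inference_stack(providers, provider_data):
    # Before this commit, a test would resolve impls and then call
    #   await impls[Api.models].register_model(model_id=..., provider_id=...)
    # Now the resource is declared up front and registered during startup:
    return await resolve_impls_for_test_v2(
        [Api.inference],
        providers,
        provider_data,
        models=[
            Model(
                identifier="my-model",  # placeholder
                provider_resource_id="my-model",
                provider_id="p0",  # placeholder provider id
            )
        ],
    )
```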