From 38257a9cbe7b53f93c8ef4808a8228dbc5988a10 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 11 Nov 2024 22:08:51 -0800 Subject: [PATCH] Allow specifying resources in StackRunConfig --- docs/openapi_generator/generate.py | 43 +--------- llama_stack/apis/resource.py | 3 + llama_stack/distribution/datatypes.py | 8 ++ llama_stack/distribution/server/server.py | 18 ++--- llama_stack/distribution/stack.py | 79 +++++++++++++++++++ llama_stack/distribution/store/registry.py | 18 ++--- .../providers/tests/inference/fixtures.py | 16 ++-- llama_stack/providers/tests/resolver.py | 25 ++++-- .../providers/tests/safety/fixtures.py | 43 +++++----- 9 files changed, 151 insertions(+), 102 deletions(-) create mode 100644 llama_stack/distribution/stack.py diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index dbfc90452..c41e3d003 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type schema_utils.json_schema_type = json_schema_type -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.batch_inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.telemetry import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.apis.synthetic_data_generation import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.inspect import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 - - -class LlamaStack( - MemoryBanks, - Inference, - BatchInference, - Agents, - Safety, - SyntheticDataGeneration, - Datasets, - Telemetry, - PostTraining, - Memory, - Eval, - EvalTasks, - Scoring, - ScoringFunctions, - DatasetIO, - Models, - Shields, - Inspect, -): - pass +from llama_stack.distribution.stack import LlamaStack # TODO: this should be fixed in the generator itself so it reads appropriate annotations diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index c386311cc..0e488190b 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -22,6 +22,9 @@ class ResourceType(Enum): class Resource(BaseModel): """Base class for all Llama Stack resources""" + # TODO: I think we need to move these into the child classes + # and make them `model_id`, `shield_id`, etc. because otherwise + # the config file has these confusing generic names in there identifier: str = Field( description="Unique identifier for this resource in llama stack" ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index d0888b981..2cba5b052 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) + # registry of "resources" in the distribution + models: List[Model] = Field(default_factory=list) + shields: List[Shield] = Field(default_factory=list) + memory_banks: List[MemoryBank] = Field(default_factory=list) + datasets: List[Dataset] = Field(default_factory=list) + scoring_fns: List[ScoringFn] = Field(default_factory=list) + eval_tasks: List[EvalTask] = Field(default_factory=list) + class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 9193583e1..bb57e2cc8 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated -from llama_stack.distribution.distribution import ( - builtin_automatically_routed_apis, - get_provider_registry, -) - -from llama_stack.distribution.store.registry import create_dist_registry +from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls +from llama_stack.distribution.resolver import InvalidProviderError +from llama_stack.distribution.stack import construct_stack from .endpoints import get_all_api_endpoints def create_sse_event(data: Any) -> str: if isinstance(data, BaseModel): - data = data.json() + data = data.model_dump_json() else: data = json.dumps(data) @@ -281,12 +277,8 @@ def main( app = FastAPI() - dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) - try: - impls = asyncio.run( - resolve_impls(config, get_provider_registry(), dist_registry) - ) + impls = asyncio.run(construct_stack(config)) except InvalidProviderError: sys.exit(1) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py new file mode 100644 index 000000000..15bb213c5 --- /dev/null +++ b/llama_stack/distribution/stack.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.eval import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.batch_inference import * # noqa: F403 +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.telemetry import * # noqa: F403 +from llama_stack.apis.post_training import * # noqa: F403 +from llama_stack.apis.synthetic_data_generation import * # noqa: F403 +from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.shields import * # noqa: F403 +from llama_stack.apis.inspect import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 + +from llama_stack.distribution.datatypes import StackRunConfig +from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.store.registry import create_dist_registry +from llama_stack.providers.datatypes import Api + + +class LlamaStack( + MemoryBanks, + Inference, + BatchInference, + Agents, + Safety, + SyntheticDataGeneration, + Datasets, + Telemetry, + PostTraining, + Memory, + Eval, + EvalTasks, + Scoring, + ScoringFunctions, + DatasetIO, + Models, + Shields, + Inspect, +): + pass + + +# Produces a stack of providers for the given run config. Not all APIs may be +# asked for in the run config. +async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: + dist_registry, _ = await create_dist_registry( + run_config.metadata_store, run_config.image_name + ) + + impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + + objects = [ + *run_config.models, + *run_config.shields, + *run_config.memory_banks, + *run_config.datasets, + *run_config.scoring_fns, + *run_config.eval_tasks, + ] + for obj in objects: + await dist_registry.register(obj) + + return impls diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 971ffabc6..6115ea1b3 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -5,14 +5,11 @@ # the root directory of this source tree. import json -from typing import Dict, List, Protocol +from typing import Dict, List, Optional, Protocol import pydantic -from llama_stack.distribution.datatypes import ( - RoutableObjectWithProvider, - StackRunConfig, -) +from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.providers.utils.kvstore import ( @@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def create_dist_registry( - config: StackRunConfig, + metadata_store: Optional[KVStoreConfig], + image_name: str, ) -> tuple[CachedDiskDistributionRegistry, KVStore]: # instantiate kvstore for storing and retrieving distribution metadata - if config.metadata_store: - dist_kvstore = await kvstore_impl(config.metadata_store) + if metadata_store: + dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( SqliteKVStoreConfig( - db_path=( - DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" - ).as_posix() + db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() ) ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d91337998..fe91c6e03 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,6 +9,8 @@ import os import pytest import pytest_asyncio +from llama_stack.apis.models import Model + from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, @@ -159,13 +161,13 @@ async def inference_stack(request, inference_model): [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, - ) - - provider_id = inference_fixture.providers[0].provider_id - print(f"Registering model {inference_model} with provider {provider_id}") - await impls[Api.models].register_model( - model_id=inference_model, - provider_id=provider_id, + models=[ + Model( + identifier=inference_model, + provider_resource_id=inference_model, + provider_id=inference_fixture.providers[0].provider_id, + ) + ], ) return (impls[Api.inference], impls[Api.models]) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 09d879c80..1353fc71b 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.distribution.stack import construct_stack +from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig async def resolve_impls_for_test_v2( apis: List[Api], providers: Dict[str, List[Provider]], provider_data: Optional[Dict[str, Any]] = None, + models: Optional[List[Model]] = None, + shields: Optional[List[Shield]] = None, + memory_banks: Optional[List[MemoryBank]] = None, + datasets: Optional[List[Dataset]] = None, + scoring_fns: Optional[List[ScoringFn]] = None, + eval_tasks: Optional[List[EvalTask]] = None, ): + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( built_at=datetime.now(), image_name="test-fixture", apis=apis, providers=providers, + metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name), + models=models or [], + shields=shields or [], + memory_banks=memory_banks or [], + datasets=datasets or [], + scoring_fns=scoring_fns or [], + eval_tasks=eval_tasks or [], ) run_config = parse_and_maybe_upgrade_config(run_config) - - sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") - dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) - dist_registry = CachedDiskDistributionRegistry(dist_kvstore) try: - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + impls = await construct_stack(run_config) except ModuleNotFoundError as e: print_pip_install_help(providers) raise e diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 10a6460cb..942e6c116 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -7,7 +7,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.shields import ShieldType +from llama_stack.apis.models import Model + +from llama_stack.apis.shields import Shield, ShieldType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request): if safety_fixture.provider_data: provider_data.update(safety_fixture.provider_data) + shield_provider_type = safety_fixture.providers[0].provider_type + shield = get_shield( + shield_provider_type, safety_fixture.providers[0].provider_id, safety_model + ) + impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, provider_data, + models=[ + Model( + identifier=inference_model, + provider_id=inference_fixture.providers[0].provider_id, + provider_resource_id=inference_model, + ) + ], + shields=[shield], ) - safety_impl = impls[Api.safety] - shields_impl = impls[Api.shields] - - # Register the appropriate shield based on provider type - provider_type = safety_fixture.providers[0].provider_type - shield = await create_and_register_shield(provider_type, safety_model, shields_impl) - - provider_id = inference_fixture.providers[0].provider_id - print(f"Registering model {inference_model} with provider {provider_id}") - await impls[Api.models].register_model( - model_id=inference_model, - provider_id=provider_id, - ) - - return safety_impl, shields_impl, shield + return impls[Api.safety], impls[Api.shields], shield -async def create_and_register_shield( - provider_type: str, safety_model: str, shields_impl -): +def get_shield(provider_type: str, provider_id: str, safety_model: str): shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard" @@ -134,8 +133,10 @@ async def create_and_register_shield( shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") shield_type = ShieldType.generic_content_shield - return await shields_impl.register_shield( - shield_id=identifier, + return Shield( + identifier=identifier, shield_type=shield_type, params=shield_config, + provider_id=provider_id, + provider_resource_id=identifier, )