Allow specifying resources in StackRunConfig

This commit is contained in:
Ashwin Bharambe 2024-11-11 22:08:51 -08:00
parent 8035fa1869
commit 38257a9cbe
9 changed files with 151 additions and 102 deletions

View file

@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type
schema_utils.json_schema_type = json_schema_type
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
from llama_stack.distribution.stack import LlamaStack
# TODO: this should be fixed in the generator itself so it reads appropriate annotations

View file

@ -22,6 +22,9 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
# TODO: I think we need to move these into the child classes
# and make them `model_id`, `shield_id`, etc. because otherwise
# the config file has these confusing generic names in there
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)

View file

@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[Model] = Field(default_factory=list)
shields: List[Shield] = Field(default_factory=list)
memory_banks: List[MemoryBank] = Field(default_factory=list)
datasets: List[Dataset] = Field(default_factory=list)
scoring_fns: List[ScoringFn] = Field(default_factory=list)
eval_tasks: List[EvalTask] = Field(default_factory=list)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import construct_stack
from .endpoints import get_all_api_endpoints
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
data = data.model_dump_json()
else:
data = json.dumps(data)
@ -281,12 +277,8 @@ def main(
app = FastAPI()
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
try:
impls = asyncio.run(
resolve_impls(config, get_provider_registry(), dist_registry)
)
impls = asyncio.run(construct_stack(config))
except InvalidProviderError:
sys.exit(1)

View file

@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
EvalTasks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
objects = [
*run_config.models,
*run_config.shields,
*run_config.memory_banks,
*run_config.datasets,
*run_config.scoring_fns,
*run_config.eval_tasks,
]
for obj in objects:
await dist_registry.register(obj)
return impls

View file

@ -5,14 +5,11 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Protocol
from typing import Dict, List, Optional, Protocol
import pydantic
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import (
@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
config: StackRunConfig,
metadata_store: Optional[KVStoreConfig],
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
)

View file

@ -9,6 +9,8 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
models=[
Model(
identifier=inference_model,
provider_resource_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
)
return (impls[Api.inference], impls[Api.models])

View file

@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=apis,
providers=providers,
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
models=models or [],
shields=shields or [],
memory_banks=memory_banks or [],
datasets=datasets or [],
scoring_fns=scoring_fns or [],
eval_tasks=eval_tasks or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
try:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
impls = await construct_stack(run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e

View file

@ -7,7 +7,9 @@
import pytest
import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.apis.models import Model
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[
Model(
identifier=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
provider_resource_id=inference_model,
)
],
shields=[shield],
)
safety_impl = impls[Api.safety]
shields_impl = impls[Api.shields]
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
return impls[Api.safety], impls[Api.shields], shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
def get_shield(provider_type: str, provider_id: str, safety_model: str):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -134,8 +133,10 @@ async def create_and_register_shield(
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
return await shields_impl.register_shield(
shield_id=identifier,
return Shield(
identifier=identifier,
shield_type=shield_type,
params=shield_config,
provider_id=provider_id,
provider_resource_id=identifier,
)