diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index dbfc90452..c41e3d003 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type schema_utils.json_schema_type = json_schema_type -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.batch_inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.telemetry import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.apis.synthetic_data_generation import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.inspect import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 - - -class LlamaStack( - MemoryBanks, - Inference, - BatchInference, - Agents, - Safety, - SyntheticDataGeneration, - Datasets, - Telemetry, - PostTraining, - Memory, - Eval, - EvalTasks, - Scoring, - ScoringFunctions, - DatasetIO, - Models, - Shields, - Inspect, -): - pass +from llama_stack.distribution.stack import LlamaStack # TODO: this should be fixed in the generator itself so it reads appropriate annotations diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index c386311cc..0e488190b 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -22,6 +22,9 @@ 
class ResourceType(Enum): class Resource(BaseModel): """Base class for all Llama Stack resources""" + # TODO: I think we need to move these into the child classes + # and make them `model_id`, `shield_id`, etc. because otherwise + # the config file has these confusing generic names in there identifier: str = Field( description="Unique identifier for this resource in llama stack" ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index d0888b981..2cba5b052 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) + # registry of "resources" in the distribution + models: List[Model] = Field(default_factory=list) + shields: List[Shield] = Field(default_factory=list) + memory_banks: List[MemoryBank] = Field(default_factory=list) + datasets: List[Dataset] = Field(default_factory=list) + scoring_fns: List[ScoringFn] = Field(default_factory=list) + eval_tasks: List[EvalTask] = Field(default_factory=list) + class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 9193583e1..bb57e2cc8 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated -from llama_stack.distribution.distribution import ( - builtin_automatically_routed_apis, - get_provider_registry, -) - -from llama_stack.distribution.store.registry import create_dist_registry +from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,14 +37,15 @@ from 
llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls +from llama_stack.distribution.resolver import InvalidProviderError +from llama_stack.distribution.stack import construct_stack from .endpoints import get_all_api_endpoints def create_sse_event(data: Any) -> str: if isinstance(data, BaseModel): - data = data.json() + data = data.model_dump_json() else: data = json.dumps(data) @@ -281,12 +277,8 @@ def main( app = FastAPI() - dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) - try: - impls = asyncio.run( - resolve_impls(config, get_provider_registry(), dist_registry) - ) + impls = asyncio.run(construct_stack(config)) except InvalidProviderError: sys.exit(1) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py new file mode 100644 index 000000000..7fe7d3ca7 --- /dev/null +++ b/llama_stack/distribution/stack.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
from typing import Any, Dict

from termcolor import colored

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.agents import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.eval import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.batch_inference import *  # noqa: F403
from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.telemetry import *  # noqa: F403
from llama_stack.apis.post_training import *  # noqa: F403
from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.inspect import *  # noqa: F403
from llama_stack.apis.eval_tasks import *  # noqa: F403

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api


# Single class that inherits from every API class pulled in by the star
# imports above. It defines no members of its own; it only combines its bases.
# The order of the bases determines the MRO, so do not reorder them.
class LlamaStack(
    MemoryBanks,
    Inference,
    BatchInference,
    Agents,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    Memory,
    Eval,
    EvalTasks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
    Models,
    Shields,
    Inspect,
):
    pass


# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
    """Build the provider implementations described by ``run_config``.

    Steps, in order:
      1. Create the distribution registry from the config's metadata store
         and image name.
      2. Resolve the provider implementations against that registry.
      3. Register every resource listed in the run config with the registry.
      4. Print, for each resource API that was resolved, the objects it
         lists and the provider serving each of them.

    Returns the mapping of API -> implementation produced by ``resolve_impls``.
    Only APIs present in that mapping are listed in the printed summary.
    """
    registry, _kvstore = await create_dist_registry(
        run_config.metadata_store, run_config.image_name
    )
    impls = await resolve_impls(run_config, get_provider_registry(), registry)

    # Resource lists on the run config, in the order they are registered.
    resource_fields = (
        "models",
        "shields",
        "memory_banks",
        "datasets",
        "scoring_fns",
        "eval_tasks",
    )
    for field_name in resource_fields:
        for resource in getattr(run_config, field_name):
            await registry.register(resource)

    # Display label -> API whose ``list_<api value>`` method enumerates it.
    apis_by_label = {
        "models": Api.models,
        "shields": Api.shields,
        "memory_banks": Api.memory_banks,
        "datasets": Api.datasets,
        "scoring_fns": Api.scoring_functions,
        "eval_tasks": Api.eval_tasks,
    }
    for label, api in apis_by_label.items():
        if api in impls:
            list_objects = getattr(impls[api], f"list_{api.value}")
            for entry in await list_objects():
                shown_id = colored(entry.identifier, "white", attrs=["bold"])
                shown_provider = colored(entry.provider_id, "white", attrs=["bold"])
                print(
                    f"{label.capitalize()}: {shown_id} served by {shown_provider}",
                )

    print("")
    return impls
import json -from typing import Dict, List, Protocol +from typing import Dict, List, Optional, Protocol import pydantic -from llama_stack.distribution.datatypes import ( - RoutableObjectWithProvider, - StackRunConfig, -) +from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.providers.utils.kvstore import ( @@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def create_dist_registry( - config: StackRunConfig, + metadata_store: Optional[KVStoreConfig], + image_name: str, ) -> tuple[CachedDiskDistributionRegistry, KVStore]: # instantiate kvstore for storing and retrieving distribution metadata - if config.metadata_store: - dist_kvstore = await kvstore_impl(config.metadata_store) + if metadata_store: + dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( SqliteKVStoreConfig( - db_path=( - DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" - ).as_posix() + db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() ) ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index a36a2c24f..2b3d0dbc4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin): if session_info.memory_bank_id is None: bank_id = f"memory_bank_{session_id}" - memory_bank = VectorMemoryBank( - identifier=bank_id, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, + await self.memory_banks_api.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ), ) - await self.memory_banks_api.register_memory_bank(memory_bank) await 
def pick_inference_model(inference_model):
    """Return the single model to use for agent inference.

    The ``inference_model`` fixture value may be either one model name or a
    list of names (when a safety model has to run next to the normal agent
    inference model). For a list, the first entry whose name does not contain
    "Llama-Guard" is returned; any other value is returned unchanged.

    This is not entirely satisfactory: the safety model is filtered out
    purely by looking for "Llama-Guard" in its name.

    Raises:
        AssertionError: if no usable model is found (the value is ``None``,
            or a list that is empty / contains only "Llama-Guard" models).
    """
    if isinstance(inference_model, list):
        # Give `next` a default: without one, a list with no usable entry
        # raises a bare StopIteration and the assertion below never runs.
        inference_model = next(
            (m for m in inference_model if "Llama-Guard" not in m),
            None,
        )
    assert (
        inference_model is not None
    ), "No inference model available (only Llama-Guard models were given)"
    return inference_model
not entirely satisfactory. The fixture `inference_model` can correspond to - # multiple models when you need to run a safety model in addition to normal agent - # inference model. We filter off the safety model by looking for "Llama-Guard" - if isinstance(inference_model, list): - inference_model = next(m for m in inference_model if "Llama-Guard" not in m) - assert inference_model is not None + inference_model = pick_inference_model(inference_model) return dict( model=inference_model, diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d91337998..fe91c6e03 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,6 +9,8 @@ import os import pytest import pytest_asyncio +from llama_stack.apis.models import Model + from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, @@ -159,13 +161,13 @@ async def inference_stack(request, inference_model): [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, - ) - - provider_id = inference_fixture.providers[0].provider_id - print(f"Registering model {inference_model} with provider {provider_id}") - await impls[Api.models].register_model( - model_id=inference_model, - provider_id=provider_id, + models=[ + Model( + identifier=inference_model, + provider_resource_id=inference_model, + provider_id=inference_fixture.providers[0].provider_id, + ) + ], ) return (impls[Api.inference], impls[Api.models]) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 482049045..456e354b2 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture: @pytest.fixture(scope="session") -def 
memory_meta_reference() -> ProviderFixture: +def memory_faiss() -> ProviderFixture: temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") return ProviderFixture( providers=[ Provider( - provider_id="meta-reference", - provider_type="meta-reference", + provider_id="faiss", + provider_type="inline::faiss", config=FaissImplConfig( kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), ).model_dump(), @@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"] +MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index a1befa6b0..24cef8a24 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -44,7 +44,6 @@ def sample_documents(): async def register_memory_bank(banks_impl: MemoryBanks): - return await banks_impl.register_memory_bank( memory_bank_id="test_bank", params=VectorMemoryBankParams( @@ -71,7 +70,7 @@ class TestMemory: # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack - bank = await banks_impl.register_memory_bank( + await banks_impl.register_memory_bank( memory_bank_id="test_bank_no_provider", params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 09d879c80..1353fc71b 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import 
set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.distribution.stack import construct_stack +from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig async def resolve_impls_for_test_v2( apis: List[Api], providers: Dict[str, List[Provider]], provider_data: Optional[Dict[str, Any]] = None, + models: Optional[List[Model]] = None, + shields: Optional[List[Shield]] = None, + memory_banks: Optional[List[MemoryBank]] = None, + datasets: Optional[List[Dataset]] = None, + scoring_fns: Optional[List[ScoringFn]] = None, + eval_tasks: Optional[List[EvalTask]] = None, ): + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( built_at=datetime.now(), image_name="test-fixture", apis=apis, providers=providers, + metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name), + models=models or [], + shields=shields or [], + memory_banks=memory_banks or [], + datasets=datasets or [], + scoring_fns=scoring_fns or [], + eval_tasks=eval_tasks or [], ) run_config = parse_and_maybe_upgrade_config(run_config) - - sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") - dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) - dist_registry = CachedDiskDistributionRegistry(dist_kvstore) try: - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + impls = await construct_stack(run_config) except ModuleNotFoundError as e: print_pip_install_help(providers) raise e diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 10a6460cb..5e553830c 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -7,7 +7,9 @@ import pytest import pytest_asyncio -from 
llama_stack.apis.shields import ShieldType +from llama_stack.apis.models import Model + +from llama_stack.apis.shields import Shield, ShieldType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig @@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request): if safety_fixture.provider_data: provider_data.update(safety_fixture.provider_data) + shield_provider_type = safety_fixture.providers[0].provider_type + shield = get_shield_to_register( + shield_provider_type, safety_fixture.providers[0].provider_id, safety_model + ) + impls = await resolve_impls_for_test_v2( [Api.safety, Api.shields, Api.inference], providers, provider_data, + models=[ + Model( + identifier=inference_model, + provider_id=inference_fixture.providers[0].provider_id, + provider_resource_id=inference_model, + ) + ], + shields=[shield], ) - safety_impl = impls[Api.safety] - shields_impl = impls[Api.shields] - - # Register the appropriate shield based on provider type - provider_type = safety_fixture.providers[0].provider_type - shield = await create_and_register_shield(provider_type, safety_model, shields_impl) - - provider_id = inference_fixture.providers[0].provider_id - print(f"Registering model {inference_model} with provider {provider_id}") - await impls[Api.models].register_model( - model_id=inference_model, - provider_id=provider_id, - ) - - return safety_impl, shields_impl, shield + return impls[Api.safety], impls[Api.shields], shield -async def create_and_register_shield( - provider_type: str, safety_model: str, shields_impl -): +def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str): shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard" @@ -134,8 +133,10 @@ async def create_and_register_shield( shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") shield_type = ShieldType.generic_content_shield - 
return await shields_impl.register_shield( - shield_id=identifier, + return Shield( + identifier=identifier, shield_type=shield_type, params=shield_config, + provider_id=provider_id, + provider_resource_id=identifier, )