From f3923e3f0be79295af73fe16245a52b08515f148 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 5 Oct 2024 08:41:36 -0700 Subject: [PATCH] Redo the { models, shields, memory_banks } typeset --- llama_stack/apis/memory/memory.py | 78 ------ llama_stack/apis/memory_banks/memory_banks.py | 66 ++++- llama_stack/apis/models/models.py | 27 +- llama_stack/apis/shields/shields.py | 36 ++- llama_stack/cli/stack/configure.py | 8 +- llama_stack/cli/tests/test_stack_build.py | 209 ++++++++------- llama_stack/distribution/configure.py | 103 +++++++- llama_stack/distribution/datatypes.py | 74 +++--- llama_stack/distribution/inspect.py | 10 +- llama_stack/distribution/resolver.py | 239 ++++++++++-------- llama_stack/distribution/routers/__init__.py | 21 +- .../distribution/routers/routing_tables.py | 155 +++++------- llama_stack/distribution/server/server.py | 8 +- llama_stack/providers/datatypes.py | 4 + .../impls/meta_reference/safety/config.py | 4 +- 15 files changed, 588 insertions(+), 454 deletions(-) diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 261dd93ee..8ac4a08a6 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -13,7 +13,6 @@ from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -26,44 +25,6 @@ class MemoryBankDocument(BaseModel): metadata: Dict[str, Any] = Field(default_factory=dict) -@json_schema_type -class MemoryBankType(Enum): - vector = "vector" - keyvalue = "keyvalue" - keyword = "keyword" - graph = "graph" - - -class VectorMemoryBankConfig(BaseModel): - type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value - embedding_model: str - chunk_size_in_tokens: int - overlap_size_in_tokens: Optional[int] = None - - -class KeyValueMemoryBankConfig(BaseModel): - type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value - - -class KeywordMemoryBankConfig(BaseModel): - type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value - - -class GraphMemoryBankConfig(BaseModel): - type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value - - -MemoryBankConfig = Annotated[ - Union[ - VectorMemoryBankConfig, - KeyValueMemoryBankConfig, - KeywordMemoryBankConfig, - GraphMemoryBankConfig, - ], - Field(discriminator="type"), -] - - class Chunk(BaseModel): content: InterleavedTextMedia token_count: int @@ -76,46 +37,7 @@ class QueryDocumentsResponse(BaseModel): scores: List[float] -@json_schema_type -class QueryAPI(Protocol): - @webmethod(route="/query_documents") - def query_documents( - self, - query: InterleavedTextMedia, - params: Optional[Dict[str, Any]] = None, - ) -> QueryDocumentsResponse: ... - - -@json_schema_type -class MemoryBank(BaseModel): - bank_id: str - name: str - config: MemoryBankConfig - # if there's a pre-existing (reachable-from-distribution) store which supports QueryAPI - url: Optional[URL] = None - - class Memory(Protocol): - @webmethod(route="/memory/create") - async def create_memory_bank( - self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: ... - - @webmethod(route="/memory/list", method="GET") - async def list_memory_banks(self) -> List[MemoryBank]: ... - - @webmethod(route="/memory/get", method="GET") - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ... - - @webmethod(route="/memory/drop", method="DELETE") - async def drop_memory_bank( - self, - bank_id: str, - ) -> str: ... - # this will just block now until documents are inserted, but it should # probably return a Job instance which can be polled for completion @webmethod(route="/memory/insert") diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 53ca83e84..d54c3868d 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -4,29 +4,67 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Protocol +from enum import Enum +from typing import List, Literal, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field - -from llama_stack.apis.memory import MemoryBankType - -from llama_stack.distribution.datatypes import GenericProviderConfig +from typing_extensions import Annotated @json_schema_type -class MemoryBankSpec(BaseModel): - bank_type: MemoryBankType - provider_config: GenericProviderConfig = Field( - description="Provider config for the model, including provider_type, and corresponding config. ", - ) +class MemoryBankType(Enum): + vector = "vector" + keyvalue = "keyvalue" + keyword = "keyword" + graph = "graph" + + +class CommonDef(BaseModel): + identifier: str + provider_id: str + + +@json_schema_type +class VectorMemoryBankDef(CommonDef): + type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + embedding_model: str + chunk_size_in_tokens: int + overlap_size_in_tokens: Optional[int] = None + + +@json_schema_type +class KeyValueMemoryBankDef(CommonDef): + type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value + + +@json_schema_type +class KeywordMemoryBankDef(CommonDef): + type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value + + +@json_schema_type +class GraphMemoryBankDef(CommonDef): + type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + + +MemoryBankDef = Annotated[ + Union[ + VectorMemoryBankDef, + KeyValueMemoryBankDef, + KeywordMemoryBankDef, + GraphMemoryBankDef, + ], + Field(discriminator="type"), +] class MemoryBanks(Protocol): @webmethod(route="/memory_banks/list", method="GET") - async def list_available_memory_banks(self) -> List[MemoryBankSpec]: ... + async def list_memory_banks(self) -> List[MemoryBankDef]: ... @webmethod(route="/memory_banks/get", method="GET") - async def get_serving_memory_bank( - self, bank_type: MemoryBankType - ) -> Optional[MemoryBankSpec]: ... + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ... + + @webmethod(route="/memory_banks/register", method="POST") + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 2952a8dee..21dd17ca2 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -6,27 +6,32 @@ from typing import List, Optional, Protocol -from llama_models.llama3.api.datatypes import Model - from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_stack.distribution.datatypes import GenericProviderConfig - @json_schema_type -class ModelServingSpec(BaseModel): - llama_model: Model = Field( - description="All metadatas associated with llama model (defined in llama_models.models.sku_list).", +class ModelDef(BaseModel): + identifier: str = Field( + description="A unique identifier for the model type", ) - provider_config: GenericProviderConfig = Field( - description="Provider config for the model, including provider_type, and corresponding config. ", + llama_model: str = Field( + description="Pointer to the core Llama family model", ) + provider_id: str = Field( + description="The provider instance which serves this model" + ) + # For now, we are only supporting core llama models but as soon as finetuned + # and other custom models (for example various quantizations) are allowed, there + # will be more metadata fields here class Models(Protocol): @webmethod(route="/models/list", method="GET") - async def list_models(self) -> List[ModelServingSpec]: ... + async def list_models(self) -> List[ModelDef]: ... @webmethod(route="/models/get", method="GET") - async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ... + async def get_model(self, identifier: str) -> Optional[ModelDef]: ... + + @webmethod(route="/models/register", method="POST") + async def register_model(self, model: ModelDef) -> None: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 2b8242263..db507a383 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -4,25 +4,43 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Protocol +from enum import Enum +from typing import Any, Dict, List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_stack.distribution.datatypes import GenericProviderConfig - @json_schema_type -class ShieldSpec(BaseModel): - shield_type: str - provider_config: GenericProviderConfig = Field( - description="Provider config for the model, including provider_type, and corresponding config. ", +class ShieldType(Enum): + generic_content_shield = "generic_content_shield" + llama_guard = "llama_guard" + code_scanner = "code_scanner" + prompt_guard = "prompt_guard" + + +class ShieldDef(BaseModel): + identifier: str = Field( + description="A unique identifier for the shield type", + ) + provider_id: str = Field( + description="The provider instance which serves this shield" + ) + type: str = Field( + description="The type of shield this is; the value is one of the ShieldType enum" + ) + params: Dict[str, Any] = Field( + default_factory=dict, + description="Any additional parameters needed for this shield", ) class Shields(Protocol): @webmethod(route="/shields/list", method="GET") - async def list_shields(self) -> List[ShieldSpec]: ... + async def list_shields(self) -> List[ShieldDef]: ... @webmethod(route="/shields/get", method="GET") - async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ... + async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: ... + + @webmethod(route="/shields/register", method="POST") + async def register_shield(self, shield: ShieldDef) -> None: ... diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index b8940ea49..e1b0aa39f 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -129,7 +129,10 @@ class StackConfigure(Subcommand): import yaml from termcolor import cprint - from llama_stack.distribution.configure import configure_api_providers + from llama_stack.distribution.configure import ( + configure_api_providers, + parse_and_maybe_upgrade_config, + ) from llama_stack.distribution.utils.serialize import EnumEncoder builds_dir = BUILDS_BASE_DIR / build_config.image_type @@ -145,7 +148,8 @@ class StackConfigure(Subcommand): "yellow", attrs=["bold"], ) - config = StackRunConfig(**yaml.safe_load(run_config_file.read_text())) + config_dict = yaml.safe_load(config_file.read_text()) + config = parse_and_maybe_upgrade_config(config_dict) else: config = StackRunConfig( built_at=datetime.now(), diff --git a/llama_stack/cli/tests/test_stack_build.py b/llama_stack/cli/tests/test_stack_build.py index 8b427a959..b04e80317 100644 --- a/llama_stack/cli/tests/test_stack_build.py +++ b/llama_stack/cli/tests/test_stack_build.py @@ -1,105 +1,142 @@ -from argparse import Namespace -from unittest.mock import MagicMock, patch - import pytest -from llama_stack.distribution.datatypes import BuildConfig -from llama_stack.cli.stack.build import StackBuild - - -# temporary while we make the tests work -pytest.skip(allow_module_level=True) +import yaml +from datetime import datetime +from llama_stack.distribution.configure import ( + parse_and_maybe_upgrade_config, + LLAMA_STACK_RUN_CONFIG_VERSION, +) @pytest.fixture -def stack_build(): - parser = MagicMock() - subparsers = MagicMock() - return StackBuild(subparsers) - - -def test_stack_build_initialization(stack_build): - assert stack_build.parser is not None - assert stack_build.parser.set_defaults.called_once_with( - func=stack_build._run_stack_build_command +def up_to_date_config(): + return yaml.safe_load( + """ + version: {version} + image_name: foo + apis_to_serve: [] + built_at: {built_at} + models: + - identifier: model1 + provider_id: provider1 + llama_model: Llama3.1-8B-Instruct + shields: + - identifier: shield1 + type: llama_guard + provider_id: provider1 + memory_banks: + - identifier: memory1 + type: vector + provider_id: provider1 + embedding_model: all-MiniLM-L6-v2 + chunk_size_in_tokens: 512 + providers: + inference: + - provider_id: provider1 + provider_type: meta-reference + config: {{}} + safety: + - provider_id: provider1 + provider_type: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-1B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + enable_prompt_guard: false + memory: + - provider_id: provider1 + provider_type: meta-reference + config: {{}} + """.format( + version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat() + ) ) -@patch("llama_stack.distribution.build.build_image") -def test_run_stack_build_command_with_config( - mock_build_image, mock_build_config, stack_build -): - args = Namespace( - config="test_config.yaml", - template=None, - list_templates=False, - name=None, - image_type="conda", +@pytest.fixture +def old_config(): + return yaml.safe_load( + """ + image_name: foo + built_at: {built_at} + apis_to_serve: [] + routing_table: + inference: + - provider_type: remote::ollama + config: + host: localhost + port: 11434 + routing_key: Llama3.2-1B-Instruct + - provider_type: meta-reference + config: + model: Llama3.1-8B-Instruct + routing_key: Llama3.1-8B-Instruct + safety: + - routing_key: ["shield1", "shield2"] + provider_type: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-1B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + enable_prompt_guard: false + memory: + - routing_key: vector + provider_type: meta-reference + config: {{}} + api_providers: + telemetry: + provider_type: noop + config: {{}} + """.format(built_at=datetime.now().isoformat()) ) - with patch("builtins.open", MagicMock()): - with patch("yaml.safe_load") as mock_yaml_load: - mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"} - mock_build_config.return_value = MagicMock() - stack_build._run_stack_build_command(args) - - mock_build_config.assert_called_once() - mock_build_image.assert_called_once() +@pytest.fixture +def invalid_config(): + return yaml.safe_load(""" + routing_table: {} + api_providers: {} + """) -@patch("llama_stack.cli.table.print_table") -def test_run_stack_build_command_list_templates(mock_print_table, stack_build): - args = Namespace(list_templates=True) - - stack_build._run_stack_build_command(args) - - mock_print_table.assert_called_once() +def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config): + result = parse_and_maybe_upgrade_config(up_to_date_config) + assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION + assert len(result.models) == 1 + assert len(result.shields) == 1 + assert len(result.memory_banks) == 1 + assert "inference" in result.providers -@patch("prompt_toolkit.prompt") -@patch("llama_stack.distribution.datatypes.BuildConfig") -@patch("llama_stack.distribution.build.build_image") -def test_run_stack_build_command_interactive( - mock_build_image, mock_build_config, mock_prompt, stack_build -): - args = Namespace( - config=None, template=None, list_templates=False, name=None, image_type=None +def test_parse_and_maybe_upgrade_config_old_format(old_config): + result = parse_and_maybe_upgrade_config(old_config) + assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION + assert len(result.models) == 2 + assert len(result.shields) == 2 + assert len(result.memory_banks) == 1 + assert all( + api in result.providers + for api in ["inference", "safety", "memory", "telemetry"] ) + safety_provider = result.providers["safety"][0] + assert safety_provider.provider_type == "meta-reference" + assert "llama_guard_shield" in safety_provider.config - mock_prompt.side_effect = [ - "test_name", - "conda", - "meta-reference", - "test description", - ] - mock_build_config.return_value = MagicMock() + inference_providers = result.providers["inference"] + assert len(inference_providers) == 2 + assert set(x.provider_id for x in inference_providers) == { + "remote::ollama-00", + "meta-reference-01", + } - stack_build._run_stack_build_command(args) - - assert mock_prompt.call_count == 4 - mock_build_config.assert_called_once() - mock_build_image.assert_called_once() + ollama = inference_providers[0] + assert ollama.provider_type == "remote::ollama" + assert ollama.config["port"] == 11434 -@patch("llama_stack.distribution.datatypes.BuildConfig") -@patch("llama_stack.distribution.build.build_image") -def test_run_stack_build_command_with_template( - mock_build_image, mock_build_config, stack_build -): - args = Namespace( - config=None, - template="test_template", - list_templates=False, - name="test_name", - image_type="docker", - ) - - with patch("builtins.open", MagicMock()): - with patch("yaml.safe_load") as mock_yaml_load: - mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"} - mock_build_config.return_value = MagicMock() - - stack_build._run_stack_build_command(args) - - mock_build_config.assert_called_once() - mock_build_image.assert_called_once() +def test_parse_and_maybe_upgrade_config_invalid(invalid_config): + with pytest.raises(ValueError): + parse_and_maybe_upgrade_config(invalid_config) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index d678a2e00..1fdde3092 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -20,7 +20,6 @@ from prompt_toolkit import prompt from prompt_toolkit.validation import Validator from termcolor import cprint -from llama_stack.apis.memory.memory import MemoryBankType from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, @@ -177,9 +176,6 @@ def configure_api_providers( ) config.routing_table[api_str] = routing_entries - config.api_providers[api_str] = PlaceholderProviderConfig( - providers=p if isinstance(p, list) else [p] - ) else: config.api_providers[api_str] = GenericProviderConfig( provider_type=p, @@ -189,3 +185,102 @@ def configure_api_providers( print("") return config + + +def upgrade_from_routing_table_to_registry( + config_dict: Dict[str, Any], +) -> Dict[str, Any]: + def get_providers(entries): + return [ + Provider( + provider_id=f"{entry['provider_type']}-{i:02d}", + provider_type=entry["provider_type"], + config=entry["config"], + ) + for i, entry in enumerate(entries) + ] + + providers_by_api = {} + models = [] + shields = [] + memory_banks = [] + + routing_table = config_dict["routing_table"] + for api_str, entries in routing_table.items(): + providers = get_providers(entries) + providers_by_api[api_str] = providers + + if api_str == "inference": + for entry, provider in zip(entries, providers): + key = entry["routing_key"] + keys = key if isinstance(key, list) else [key] + for key in keys: + models.append( + ModelDef( + identifier=key, + provider_id=provider.provider_id, + llama_model=key, + ) + ) + elif api_str == "safety": + for entry, provider in zip(entries, providers): + key = entry["routing_key"] + keys = key if isinstance(key, list) else [key] + for key in keys: + shields.append( + ShieldDef( + identifier=key, + type=ShieldType.llama_guard.value, + provider_id=provider.provider_id, + ) + ) + elif api_str == "memory": + for entry, provider in zip(entries, providers): + key = entry["routing_key"] + keys = key if isinstance(key, list) else [key] + for key in keys: + # we currently only support Vector memory banks so this is OK + memory_banks.append( + VectorMemoryBankDef( + identifier=key, + provider_id=provider.provider_id, + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ) + ) + config_dict["models"] = models + config_dict["shields"] = shields + config_dict["memory_banks"] = memory_banks + + if "api_providers" in config_dict: + for api_str, provider in config_dict["api_providers"].items(): + if isinstance(provider, dict): + providers_by_api[api_str] = [ + Provider( + provider_id=f"{provider['provider_type']}-00", + provider_type=provider["provider_type"], + config=provider["config"], + ) + ] + + config_dict["providers"] = providers_by_api + + del config_dict["routing_table"] + del config_dict["api_providers"] + + return config_dict + + +def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig: + version = config_dict.get("version", None) + if version == LLAMA_STACK_RUN_CONFIG_VERSION: + return StackRunConfig(**config_dict) + + if "models" not in config_dict: + print("Upgrading config...") + config_dict = upgrade_from_routing_table_to_registry(config_dict) + + config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION + config_dict["built_at"] = datetime.now().isoformat() + + return StackRunConfig(**config_dict) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 09778a761..bccb7d705 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -11,10 +11,13 @@ from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field from llama_stack.providers.datatypes import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.shields import * # noqa: F403 +from llama_stack.apis.memory_banks import * # noqa: F403 -LLAMA_STACK_BUILD_CONFIG_VERSION = "v1" -LLAMA_STACK_RUN_CONFIG_VERSION = "v1" +LLAMA_STACK_BUILD_CONFIG_VERSION = "2" +LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = Union[str, List[str]] @@ -29,12 +32,6 @@ class RoutableProviderConfig(GenericProviderConfig): routing_key: RoutingKey -class PlaceholderProviderConfig(BaseModel): - """Placeholder provider config for API whose provider are defined in routing_table""" - - providers: List[str] - - # Example: /inference, /safety class AutoRoutedProviderSpec(ProviderSpec): provider_type: str = "router" @@ -53,18 +50,16 @@ class AutoRoutedProviderSpec(ProviderSpec): # Example: /models, /shields -@json_schema_type class RoutingTableProviderSpec(ProviderSpec): provider_type: str = "routing_table" config_class: str = "" docker_image: Optional[str] = None - inner_specs: List[ProviderSpec] + router_api: Api module: str pip_packages: List[str] = Field(default_factory=list) -@json_schema_type class DistributionSpec(BaseModel): description: Optional[str] = Field( default="", @@ -80,7 +75,12 @@ in the runtime configuration to help route to the correct provider.""", ) -@json_schema_type +class Provider(BaseModel): + provider_id: str + provider_type: str + config: Dict[str, Any] + + class StackRunConfig(BaseModel): version: str = LLAMA_STACK_RUN_CONFIG_VERSION built_at: datetime @@ -105,31 +105,37 @@ this could be just a hash The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", ) - api_providers: Dict[ - str, Union[GenericProviderConfig, PlaceholderProviderConfig] - ] = Field( - description=""" -Provider configurations for each of the APIs provided by this package. -""", - ) - routing_table: Dict[str, List[RoutableProviderConfig]] = Field( - default_factory=dict, - description=""" + providers: Dict[str, List[Provider]] - E.g. The following is a ProviderRoutingEntry for models: - - routing_key: Llama3.1-8B-Instruct - provider_type: meta-reference - config: - model: Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - """, - ) + models: List[ModelDef] + memory_banks: List[MemoryBankDef] + shields: List[ShieldDef] + + +# api_providers: Dict[ +# str, Union[GenericProviderConfig, PlaceholderProviderConfig] +# ] = Field( +# description=""" +# Provider configurations for each of the APIs provided by this package. +# """, +# ) +# routing_table: Dict[str, List[RoutableProviderConfig]] = Field( +# default_factory=dict, +# description=""" + +# E.g. The following is a ProviderRoutingEntry for models: +# - routing_key: Llama3.1-8B-Instruct +# provider_type: meta-reference +# config: +# model: Llama3.1-8B-Instruct +# quantization: null +# torch_seed: null +# max_seq_len: 4096 +# max_batch_size: 1 +# """, +# ) -@json_schema_type class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION name: str diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index acd7ab7f8..07a851e78 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -6,15 +6,19 @@ from typing import Dict, List from llama_stack.apis.inspect import * # noqa: F403 - +from pydantic import BaseModel from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.providers.datatypes import * # noqa: F403 -def is_passthrough(spec: ProviderSpec) -> bool: - return isinstance(spec, RemoteProviderSpec) and spec.adapter is None +class DistributionInspectConfig(BaseModel): + pass + + +def get_provider_impl(*args, **kwargs): + return DistributionInspectImpl() class DistributionInspectImpl(Inspect): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index ae7d9ab40..ec8374290 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -12,138 +12,187 @@ from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, ) -from llama_stack.distribution.inspect import DistributionInspectImpl from llama_stack.distribution.utils.dynamic import instantiate_class_type +# TODO: make all this naming far less atrocious. Provider. ProviderSpec. ProviderWithSpec. WTF! +class ProviderWithSpec(Provider): + spec: ProviderSpec + + async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order - for each API, produces either a (local, passthrough or router) implementation """ - all_providers = get_provider_registry() - specs = {} - configs = {} + all_api_providers = get_provider_registry() - for api_str, config in run_config.api_providers.items(): + auto_routed_apis = builtin_automatically_routed_apis() + providers_with_specs = {} + + for api_str, instances in run_config.providers.items(): api = Api(api_str) - - # TODO: check that these APIs are not in the routing table part of the config - providers = all_providers[api] - - # skip checks for API whose provider config is specified in routing_table - if isinstance(config, PlaceholderProviderConfig): - continue - - if config.provider_type not in providers: + if api in [a.routing_table_api for a in auto_routed_apis]: raise ValueError( - f"Provider `{config.provider_type}` is not available for API `{api}`" + f"Provider for `{api_str}` is automatically provided and cannot be overridden" ) - specs[api] = providers[config.provider_type] - configs[api] = config + + providers_with_specs[api] = {} + for config in instances: + if config.provider_type not in all_api_providers[api]: + raise ValueError( + f"Provider `{config.provider_type}` is not available for API `{api}`" + ) + + spec = ProviderWithSpec( + spec=all_api_providers[api][config.provider_type], + **config, + ) + providers_with_specs[api][spec.provider_id] = spec apis_to_serve = run_config.apis_to_serve or set( - list(specs.keys()) + list(run_config.routing_table.keys()) + list(providers_with_specs.keys()) + + [a.routing_table_api.value for a in auto_routed_apis] ) for info in builtin_automatically_routed_apis(): - source_api = info.routing_table_api - - assert ( - source_api not in specs - ), f"Routing table API {source_api} specified in wrong place?" - assert ( - info.router_api not in specs - ), f"Auto-routed API {info.router_api} specified in wrong place?" - if info.router_api.value not in apis_to_serve: continue - if info.router_api.value not in run_config.routing_table: - raise ValueError(f"Routing table for `{source_api.value}` is not provided?") + if info.routing_table_api.value not in run_config: + raise ValueError( + f"Registry for `{info.routing_table_api.value}` is not provided?" + ) - routing_table = run_config.routing_table[info.router_api.value] + available_providers = providers_with_specs[info.router_api] - providers = all_providers[info.router_api] - - inner_specs = [] inner_deps = [] - for rt_entry in routing_table: - if rt_entry.provider_type not in providers: + registry = run_config[info.routing_table_api.value] + for entry in registry: + if entry.provider_id not in available_providers: raise ValueError( - f"Provider `{rt_entry.provider_type}` is not available for API `{api}`" + f"Provider `{entry.provider_id}` not found. Available providers: {list(available_providers.keys())}" ) - inner_specs.append(providers[rt_entry.provider_type]) - inner_deps.extend(providers[rt_entry.provider_type].api_dependencies) - specs[source_api] = RoutingTableProviderSpec( - api=source_api, - module="llama_stack.distribution.routers", - api_dependencies=inner_deps, - inner_specs=inner_specs, + provider = available_providers[entry.provider_id] + inner_deps.extend(provider.spec.api_dependencies) + + providers_with_specs[info.routing_table_api] = { + "__builtin__": [ + ProviderWithSpec( + provider_id="__builtin__", + provider_type="__builtin__", + config=registry, + spec=RoutingTableProviderSpec( + api=info.routing_table_api, + router_api=info.router_api, + module="llama_stack.distribution.routers", + api_dependencies=inner_deps, + ), + ) + ] + } + + providers_with_specs[info.router_api] = { + "__builtin__": [ + ProviderWithSpec( + provider_id="__builtin__", + provider_type="__builtin__", + config={}, + spec=AutoRoutedProviderSpec( + api=info.router_api, + module="llama_stack.distribution.routers", + routing_table_api=source_api, + api_dependencies=[source_api], + ), + ) + ] + } + + sorted_providers = topological_sort(providers_with_specs) + sorted_providers.append( + ProviderWithSpec( + provider_id="__builtin__", + provider_type="__builtin__", + config={}, + spec=InlineProviderSpec( + api=Api.inspect, + provider_type="__builtin__", + config_class="llama_stack.distribution.inspect.DistributionInspectConfig", + module="llama_stack.distribution.inspect", + ), ) - configs[source_api] = routing_table - - specs[info.router_api] = AutoRoutedProviderSpec( - api=info.router_api, - module="llama_stack.distribution.routers", - routing_table_api=source_api, - api_dependencies=[source_api], - ) - configs[info.router_api] = {} - - sorted_specs = topological_sort(specs.values()) - print(f"Resolved {len(sorted_specs)} providers in topological order") - for spec in sorted_specs: - print(f" {spec.api}: {spec.provider_type}") - print("") - impls = {} - for spec in sorted_specs: - api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} - impl = await instantiate_provider(spec, deps, configs[api]) - - impls[api] = impl - - impls[Api.inspect] = DistributionInspectImpl() - specs[Api.inspect] = InlineProviderSpec( - api=Api.inspect, - provider_type="__distribution_builtin__", - config_class="", - module="", ) - return impls, specs + print(f"Resolved {len(sorted_providers)} providers in topological order") + for provider in sorted_providers: + print( + f" {provider.spec.api}: ({provider.provider_id}) {provider.spec.provider_type}" + ) + print("") + impls = {} + + impls_by_provider_id = {} + for provider in sorted_providers: + api = provider.spec.api + if api not in impls_by_provider_id: + impls_by_provider_id[api] = {} + + deps = {api: impls[api] for api in provider.spec.api_dependencies} + + inner_impls = {} + if isinstance(provider.spec, RoutingTableProviderSpec): + for entry in provider.config: + inner_impls[entry.provider_id] = impls_by_provider_id[ + provider.spec.router_api + ][entry.provider_id] + + impl = await instantiate_provider( + provider, + deps, + inner_impls, + ) + + impls[api] = impl + impls_by_provider_id[api][provider.provider_id] = impl + + return impls -def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]: - by_id = {x.api: x for x in providers} +def topological_sort( + providers_with_specs: Dict[Api, List[ProviderWithSpec]], +) -> List[ProviderWithSpec]: + def dfs(kv, visited: Set[Api], stack: List[Api]): + api, providers = kv + visited.add(api) - def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): - visited.add(a.api) - - for api in a.api_dependencies: + deps = [dep for x in providers for dep in x.api_dependencies] + for api in deps: if api not in visited: - dfs(by_id[api], visited, stack) + dfs((api, providers_with_specs[api]), visited, stack) - stack.append(a.api) + stack.append(api) visited = set() stack = [] - for a in providers: - if a.api not in visited: - dfs(a, visited, stack) + for api, providers in providers_with_specs.items(): + if api not in visited: + dfs((api, providers), visited, stack) - return [by_id[x] for x in stack] + flattened = [] + for api in stack: + flattened.extend(providers_with_specs[api]) + return flattened # returns a class implementing the protocol corresponding to the Api async def instantiate_provider( - provider_spec: ProviderSpec, + provider: ProviderWithSpec, deps: Dict[str, Any], - provider_config: Union[GenericProviderConfig, RoutingTable], + inner_impls: Dict[str, Any], ): + provider_spec = provider.spec module = importlib.import_module(provider_spec.module) args = [] @@ -165,21 +214,11 @@ async def instantiate_provider( elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" - assert isinstance(provider_config, List) - routing_table = provider_config - - inner_specs = {x.provider_type: x for x in provider_spec.inner_specs} - inner_impls = [] - for routing_entry in routing_table: - impl = await instantiate_provider( - inner_specs[routing_entry.provider_type], - deps, - routing_entry, - ) - inner_impls.append((routing_entry.routing_key, impl)) + assert isinstance(provider_config, list) + registry = provider_config config = None - args = [provider_spec.api, inner_impls, routing_table, deps] + args = [provider_spec.api, registry, inner_impls, deps] else: method = "get_provider_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 363c863aa..0464ab57a 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,23 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, List, Tuple +from typing import Any, List from llama_stack.distribution.datatypes import * # noqa: F403 +from .routing_tables import ( + MemoryBanksRoutingTable, + ModelsRoutingTable, + RoutableObject, + RoutedProtocol, + ShieldsRoutingTable, +) async def get_routing_table_impl( api: Api, - inner_impls: List[Tuple[str, Any]], - routing_table_config: Dict[str, List[RoutableProviderConfig]], + registry: List[RoutableObject], + impls_by_provider_id: Dict[str, RoutedProtocol], _deps, ) -> Any: - from .routing_tables import ( - MemoryBanksRoutingTable, - ModelsRoutingTable, - ShieldsRoutingTable, - ) - api_to_tables = { "memory_banks": MemoryBanksRoutingTable, "models": ModelsRoutingTable, @@ -29,7 +30,7 @@ async def get_routing_table_impl( if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_tables[api.value](inner_impls, routing_table_config) + impl = api_to_tables[api.value](registry, impls_by_provider_id) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index e5db17edc..01d92ff12 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -4,141 +4,106 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Union -from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.inference import Inference +from llama_stack.apis.memory import Memory +from llama_stack.apis.safety import Safety from llama_stack.distribution.datatypes import * # noqa: F403 +RoutableObject = Union[ + ModelDef, + ShieldDef, + MemoryBankDef, +] + +RoutedProtocol = Union[ + Inference, + Safety, + Memory, +] + + class CommonRoutingTableImpl(RoutingTable): def __init__( self, - inner_impls: List[Tuple[RoutingKey, Any]], - routing_table_config: Dict[str, List[RoutableProviderConfig]], + registry: List[RoutableObject], + impls_by_provider_id: Dict[str, RoutedProtocol], ) -> None: - self.unique_providers = [] - self.providers = {} - self.routing_keys = [] + for obj in registry: + if obj.provider_id not in impls_by_provider_id: + raise ValueError( + f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found" + ) - for key, impl in inner_impls: - keys = key if isinstance(key, list) else [key] - self.unique_providers.append((keys, impl)) - - for k in keys: - if k in self.providers: - raise ValueError(f"Duplicate routing key {k}") - self.providers[k] = impl - self.routing_keys.append(k) - - self.routing_table_config = routing_table_config + self.impls_by_provider_id = impls_by_provider_id + self.registry = registry async def initialize(self) -> None: - for keys, p in self.unique_providers: + keys_by_provider = {} + for obj in self.registry: + keys = keys_by_provider.setdefault(obj.provider_id, []) + keys.append(obj.routing_key) + + for provider_id, keys in keys_by_provider.items(): + p = self.impls_by_provider_id[provider_id] spec = p.__provider_spec__ - if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: + if is_passthrough(spec): continue await p.validate_routing_keys(keys) async def shutdown(self) -> None: - for _, p in self.unique_providers: - await p.shutdown() + pass def get_provider_impl(self, routing_key: str) -> Any: - if routing_key not in self.providers: + if routing_key not in self.routing_key_to_object: raise ValueError(f"Could not find provider for {routing_key}") - return self.providers[routing_key] + obj = self.routing_key_to_object[routing_key] + return self.impls_by_provider_id[obj.provider_id] - def get_routing_keys(self) -> List[str]: - return self.routing_keys - - def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]: - for entry in self.routing_table_config: - if entry.routing_key == routing_key: - return entry + def get_object_by_identifier(self, identifier: str) -> Optional[RoutableObject]: + for obj in self.registry: + if obj.identifier == identifier: + return obj return None class ModelsRoutingTable(CommonRoutingTableImpl, Models): + async def list_models(self) -> List[ModelDef]: + return self.registry - async def list_models(self) -> List[ModelServingSpec]: - specs = [] - for entry in self.routing_table_config: - model_id = entry.routing_key - specs.append( - ModelServingSpec( - llama_model=resolve_model(model_id), - provider_config=entry, - ) - ) - return specs + async def get_model(self, identifier: str) -> Optional[ModelDef]: + return self.get_object_by_identifier(identifier) - async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: - for entry in self.routing_table_config: - if entry.routing_key == core_model_id: - return ModelServingSpec( - llama_model=resolve_model(core_model_id), - provider_config=entry, - ) - return None + async def register_model(self, model: ModelDef) -> None: + raise NotImplementedError() class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def list_shields(self) -> List[ShieldDef]: + return self.registry - async def list_shields(self) -> List[ShieldSpec]: - specs = [] - for entry in self.routing_table_config: - if isinstance(entry.routing_key, list): - for k in entry.routing_key: - specs.append( - ShieldSpec( - shield_type=k, - provider_config=entry, - ) - ) - else: - specs.append( - ShieldSpec( - shield_type=entry.routing_key, - provider_config=entry, - ) - ) - return specs + async def get_shield(self, shield_type: str) -> Optional[ShieldDef]: + return self.get_object_by_identifier(shield_type) - async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: - for entry in self.routing_table_config: - if entry.routing_key == shield_type: - return ShieldSpec( - shield_type=entry.routing_key, - provider_config=entry, - ) - return None + async def register_shield(self, shield: ShieldDef) -> None: + raise NotImplementedError() class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): + async def list_memory_banks(self) -> List[MemoryBankDef]: + return self.registry - async def list_available_memory_banks(self) -> List[MemoryBankSpec]: - specs = [] - for entry in self.routing_table_config: - specs.append( - MemoryBankSpec( - bank_type=entry.routing_key, - provider_config=entry, - ) - ) - return specs + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: + return self.get_object_by_identifier(identifier) - async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]: - for entry in self.routing_table_config: - if entry.routing_key == bank_type: - return MemoryBankSpec( - bank_type=entry.routing_key, - provider_config=entry, - ) - return None + async def register_memory_bank(self, bank: MemoryBankDef) -> None: + raise NotImplementedError() diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 4013264df..f664bb674 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -285,7 +285,7 @@ def main( app = FastAPI() - impls, specs = asyncio.run(resolve_impls_with_routing(config)) + impls = asyncio.run(resolve_impls_with_routing(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) @@ -303,11 +303,7 @@ def main( endpoints = all_endpoints[api] impl = impls[api] - provider_spec = specs[api] - if ( - isinstance(provider_spec, RemoteProviderSpec) - and provider_spec.adapter is None - ): + if is_passthrough(impl.__provider_spec__): for endpoint in endpoints: url = impl.__provider_config__.url.rstrip("/") + endpoint.route getattr(app, endpoint.method)(endpoint.route)( diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index a2e8851a2..abc1d601d 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -154,6 +154,10 @@ as being "Llama Stack compatible" return None +def is_passthrough(spec: ProviderSpec) -> bool: + return isinstance(spec, RemoteProviderSpec) and spec.adapter is None + + # Can avoid this by using Pydantic computed_field def remote_provider_spec( api: Api, adapter: Optional[AdapterSpec] = None diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py index 64a39b3c6..4f6de544b 100644 --- a/llama_stack/providers/impls/meta_reference/safety/config.py +++ b/llama_stack/providers/impls/meta_reference/safety/config.py @@ -9,7 +9,7 @@ from typing import List, Optional from llama_models.sku_list import CoreModelId, safety_models -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator class MetaReferenceShieldType(Enum): @@ -25,7 +25,7 @@ class LlamaGuardShieldConfig(BaseModel): disable_input_check: bool = False disable_output_check: bool = False - @validator("model") + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: permitted_models = [