From c1ab66f1e6b33f774a9ffea14543ec1c8113bfd0 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sun, 22 Sep 2024 16:31:18 -0700
Subject: [PATCH] Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider,
  e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks correspondences

* typo

* Basic build and run succeeded
---
 .../memory_banks/__init__.py}                 |   6 +-
 llama_stack/apis/memory_banks/client.py       |  67 +++++++++
 llama_stack/apis/memory_banks/memory_banks.py |  32 +++++
 llama_stack/apis/models/client.py             |  19 ++-
 llama_stack/apis/models/models.py             |  25 +---
 llama_stack/apis/shields/__init__.py          |   7 +
 llama_stack/apis/shields/client.py            |  67 +++++++++
 llama_stack/apis/shields/shields.py           |  28 ++++
 llama_stack/distribution/datatypes.py         |  74 +++++-----
 llama_stack/distribution/distribution.py      |  34 +++++
 llama_stack/distribution/routers/__init__.py  |  50 +++++--
 llama_stack/distribution/routers/routers.py   |  70 ++++++----
 .../distribution/routers/routing_table.py     |  60 --------
 .../distribution/routers/routing_tables.py    | 118 ++++++++++++++++
 llama_stack/distribution/server/server.py     | 128 ++++++++++++------
 llama_stack/distribution/utils/dynamic.py     |  50 ++++---
 .../impls/builtin/models/__init__.py          |  21 ---
 .../providers/impls/builtin/models/models.py  | 113 ----------------
 .../impls/meta_reference/inference/config.py  |  10 +-
 llama_stack/providers/registry/inference.py   |  14 --
 llama_stack/providers/registry/models.py      |  22 ---
 21 files changed, 597 insertions(+), 418 deletions(-)
 rename llama_stack/{providers/impls/builtin/models/config.py => apis/memory_banks/__init__.py} (58%)
 create mode 100644 llama_stack/apis/memory_banks/client.py
 create mode 100644 llama_stack/apis/memory_banks/memory_banks.py
 create mode 100644 llama_stack/apis/shields/__init__.py
 create mode 100644 llama_stack/apis/shields/client.py
 create mode 100644 llama_stack/apis/shields/shields.py
 delete mode 100644 llama_stack/distribution/routers/routing_table.py
 create mode 100644 llama_stack/distribution/routers/routing_tables.py
 delete mode 100644 llama_stack/providers/impls/builtin/models/__init__.py
 delete mode 100644 llama_stack/providers/impls/builtin/models/models.py
 delete mode 100644 llama_stack/providers/registry/models.py

diff --git a/llama_stack/providers/impls/builtin/models/config.py b/llama_stack/apis/memory_banks/__init__.py
similarity index 58%
rename from llama_stack/providers/impls/builtin/models/config.py
rename to llama_stack/apis/memory_banks/__init__.py
index 0a21e3b20..7511677ab 100644
--- a/llama_stack/providers/impls/builtin/models/config.py
+++ b/llama_stack/apis/memory_banks/__init__.py
@@ -3,9 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel
-
-
-@json_schema_type
-class BuiltinImplConfig(BaseModel): ...
+from .memory_banks import *  # noqa: F401 F403
diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py
new file mode 100644
index 000000000..1a648927e
--- /dev/null
+++ b/llama_stack/apis/memory_banks/client.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import List, Optional
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .memory_banks import *  # noqa: F403
+
+
+class MemoryBanksClient(MemoryBanks):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_memory_banks(self) -> List[MemoryBankSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/memory_banks/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return [MemoryBankSpec(**x) for x in response.json()]
+
+    async def get_memory_bank(
+        self, bank_type: MemoryBankType
+    ) -> Optional[MemoryBankSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/memory_banks/get",
+                params={
+                    "bank_type": bank_type.value,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            j = response.json()
+            if j is None:
+                return None
+            return MemoryBankSpec(**j)
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = MemoryBanksClient(f"http://{host}:{port}")
+
+    response = await client.list_memory_banks()
+    cprint(f"list_memory_banks response={response}", "green")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
new file mode 100644
index 000000000..23bfb69e1
--- /dev/null
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+
+from llama_stack.apis.memory import MemoryBankType
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
+
+@json_schema_type
+class MemoryBankSpec(BaseModel):
+    bank_type: MemoryBankType
+    provider_config: GenericProviderConfig = Field(
+        description="Provider config for the memory bank, including provider_id and corresponding config.",
+    )
+
+
+class MemoryBanks(Protocol):
+    @webmethod(route="/memory_banks/list", method="GET")
+    async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
+
+    @webmethod(route="/memory_banks/get", method="GET")
+    async def get_memory_bank(
+        self, bank_type: MemoryBankType
+    ) -> Optional[MemoryBankSpec]: ...
diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py
index 929265f9e..d97184997 100644
--- a/llama_stack/apis/models/client.py
+++ b/llama_stack/apis/models/client.py
@@ -5,15 +5,11 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import fire
 import httpx
-
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from termcolor import cprint
 
 from .models import *  # noqa: F403
@@ -29,18 +25,18 @@ class ModelsClient(Models):
     async def shutdown(self) -> None:
         pass
 
-    async def list_models(self) -> ModelsListResponse:
+    async def list_models(self) -> List[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/models/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsListResponse(**response.json())
+            return [ModelServingSpec(**x) for x in response.json()]
 
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
-            response = await client.post(
+            response = await client.get(
                 f"{self.base_url}/models/get",
-                json={
+                params={
                     "core_model_id": core_model_id,
                 },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsGetResponse(**response.json())
+            j = response.json()
+            if j is None:
+                return None
+            return ModelServingSpec(**j)
 
 
 async def run_main(host: str, port: int, stream: bool):
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 8bbe7f6de..d542517ba 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional, Protocol
+from typing import List, Optional, Protocol
 
 from llama_models.llama3.api.datatypes import Model
 from llama_models.schema_utils import json_schema_type, webmethod
-from llama_stack.distribution.datatypes import GenericProviderConfig
 from pydantic import BaseModel, Field
 
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
 
 @json_schema_type
 class ModelServingSpec(BaseModel):
@@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
     provider_config: GenericProviderConfig = Field(
         description="Provider config for the model, including provider_id, and corresponding config. ",
     )
-    api: str = Field(
-        description="The API that this model is serving (e.g. inference / safety).",
-        default="inference",
-    )
-
-
-@json_schema_type
-class ModelsListResponse(BaseModel):
-    models_list: List[ModelServingSpec]
-
-
-@json_schema_type
-class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelServingSpec] = None
 
 
 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> ModelsListResponse: ...
+    async def list_models(self) -> List[ModelServingSpec]: ...
 
-    @webmethod(route="/models/get", method="POST")
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
+    @webmethod(route="/models/get", method="GET")
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
diff --git a/llama_stack/apis/shields/__init__.py b/llama_stack/apis/shields/__init__.py
new file mode 100644
index 000000000..edad26100
--- /dev/null
+++ b/llama_stack/apis/shields/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .shields import *  # noqa: F401 F403
diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py
new file mode 100644
index 000000000..5e5001d90
--- /dev/null
+++ b/llama_stack/apis/shields/client.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import List, Optional
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .shields import *  # noqa: F403
+
+
+class ShieldsClient(Shields):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_shields(self) -> List[ShieldSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/shields/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return [ShieldSpec(**x) for x in response.json()]
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/shields/get",
+                params={
+                    "shield_type": shield_type,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+
+            j = response.json()
+            if j is None:
+                return None
+
+            return ShieldSpec(**j)
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = ShieldsClient(f"http://{host}:{port}")
+
+    response = await client.list_shields()
+    cprint(f"list_shields response={response}", "green")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
new file mode 100644
index 000000000..006178b5d
--- /dev/null
+++ b/llama_stack/apis/shields/shields.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
+
+@json_schema_type
+class ShieldSpec(BaseModel):
+    shield_type: str
+    provider_config: GenericProviderConfig = Field(
+        description="Provider config for the shield, including provider_id and corresponding config.",
+    )
+
+
+class Shields(Protocol):
+    @webmethod(route="/shields/list", method="GET")
+    async def list_shields(self) -> List[ShieldSpec]: ...
+
+    @webmethod(route="/shields/get", method="GET")
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
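
For orientation (an illustration, not part of the patch): all three registry APIs above serve entries out of the run config's routing tables, using the RoutableProviderConfig / RoutingTableConfig datatypes introduced in the next file. A minimal Python sketch, where the provider id and config values are hypothetical placeholders:

    from llama_stack.distribution.datatypes import (
        RoutableProviderConfig,
        RoutingTableConfig,
    )

    # Hypothetical shields routing table; the provider_id and config values
    # are illustrative, not taken from this patch.
    shields_table = RoutingTableConfig(
        entries=[
            RoutableProviderConfig(
                routing_key="llama_guard",  # the shield_type a client queries
                provider_id="meta-reference",
                config={"model": "Llama-Guard-3-8B"},
            )
        ]
    )

ShieldsRoutingTable (added later in this patch) exposes each such entry as a ShieldSpec via GET /shields/list and GET /shields/get.
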
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 5cc4e56ff..52522886e 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -6,7 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type
 
@@ -19,8 +19,12 @@ class Api(Enum):
     safety = "safety"
     agents = "agents"
     memory = "memory"
+    telemetry = "telemetry"
+    models = "models"
+    shields = "shields"
+    memory_banks = "memory_banks"
 
 
 @json_schema_type
@@ -44,27 +48,36 @@ class ProviderSpec(BaseModel):
     )
 
 
+class RoutingTable(Protocol):
+    def get_routing_keys(self) -> List[str]: ...
+
+    def get_provider_impl(self, routing_key: str) -> Any: ...
+
+
 class GenericProviderConfig(BaseModel):
     provider_id: str
     config: Dict[str, Any]
 
 
-@json_schema_type
-class ProviderRoutingEntry(GenericProviderConfig):
+class RoutableProviderConfig(GenericProviderConfig):
     routing_key: str
 
 
+class RoutingTableConfig(BaseModel):
+    entries: List[RoutableProviderConfig] = Field(...)
+    keys: Optional[List[str]] = Field(
+        default=None,
+    )
+
+
+# Example: /inference, /safety
 @json_schema_type
-class RouterProviderSpec(ProviderSpec):
+class AutoRoutedProviderSpec(ProviderSpec):
     provider_id: str = "router"
     config_class: str = ""
     docker_image: Optional[str] = None
-
-    routing_table: List[ProviderRoutingEntry] = Field(
-        default_factory=list,
-        description="Routing table entries corresponding to the API",
-    )
+    routing_table_api: Api
 
     module: str = Field(
         ...,
         description="""
@@ -79,18 +92,17 @@ class RouterProviderSpec(ProviderSpec):
 
     @property
     def pip_packages(self) -> List[str]:
-        raise AssertionError("Should not be called on RouterProviderSpec")
+        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
 
 
+# Example: /models, /shields
 @json_schema_type
-class BuiltinProviderSpec(ProviderSpec):
-    provider_id: str = "builtin"
+class RoutingTableProviderSpec(ProviderSpec):
+    provider_id: str = "routing_table"
     config_class: str = ""
     docker_image: Optional[str] = None
-    api_dependencies: List[Api] = []
-    provider_data_validator: Optional[str] = Field(
-        default=None,
-    )
+
+    inner_specs: List[ProviderSpec]
 
     module: str = Field(
         ...,
         description="""
@@ -99,10 +111,7 @@ class BuiltinProviderSpec(ProviderSpec):
 - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
 """,
     )
-    pip_packages: List[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
+    pip_packages: List[str] = Field(default_factory=list)
 
 
 @json_schema_type
@@ -130,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
     provider_data_validator: Optional[str] = Field(
         default=None,
     )
-    supported_model_ids: List[str] = Field(
-        default_factory=list,
-        description="The list of model ids that this adapter supports",
-    )
 
 
 @json_schema_type
@@ -243,9 +248,6 @@ in the runtime configuration to help route to the correct provider.""",
 )
 
 
-ProviderMapEntry = GenericProviderConfig
-
-
 @json_schema_type
 class StackRunConfig(BaseModel):
     built_at: datetime
@@ -269,25 +271,17 @@ this could be just a hash
         description="""
 The list of APIs to serve.
If not specified, all APIs specified in the provider_map will be served""",
     )
 
-    provider_map: Dict[str, ProviderMapEntry] = Field(
+    api_providers: Dict[str, GenericProviderConfig] = Field(
         description="""
 Provider configurations for each of the APIs provided by this package.
-
-Given an API, you can specify a single provider or a "routing table". Each entry in the routing
-table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
-
-As examples:
-- the "inference" API interprets the routing_key as a "model"
-- the "memory" API interprets the routing_key as the type of a "memory bank"
-
-The key may support wild-cards also; the routing_key is used to route to the correct provider.""",
+""",
     )
 
-    provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
+    routing_tables: Dict[str, RoutingTableConfig] = Field(
         default_factory=dict,
         description="""
-        API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
-        E.g. The following is a ProviderRoutingEntry for inference API:
+        E.g. The following is a RoutableProviderConfig entry for the models routing table:
         - routing_key: Meta-Llama3.1-8B-Instruct
           provider_id: meta-reference
           config:
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 3dd406ccc..6b72afed5 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -8,11 +8,15 @@ import importlib
 import inspect
 from typing import Dict, List
 
+from pydantic import BaseModel
+
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 
 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
@@ -30,6 +34,28 @@ def stack_apis() -> List[Api]:
     return [v for v in Api]
 
 
+class AutoRoutedApiInfo(BaseModel):
+    routing_table_api: Api
+    router_api: Api
+
+
+def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
+    return [
+        AutoRoutedApiInfo(
+            routing_table_api=Api.models,
+            router_api=Api.inference,
+        ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.shields,
+            router_api=Api.safety,
+        ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.memory_banks,
+            router_api=Api.memory,
+        ),
+    ]
+
+
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}
 
@@ -40,6 +66,8 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.memory: Memory,
         Api.telemetry: Telemetry,
         Api.models: Models,
+        Api.shields: Shields,
+        Api.memory_banks: MemoryBanks,
     }
 
     for api, protocol in protocols.items():
@@ -68,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
+    routing_table_apis = set(
+        x.routing_table_api for x in builtin_automatically_routed_apis()
+    )
     for api in stack_apis():
+        if api in routing_table_apis:
+            continue
+
         name = api.name.lower()
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index 9f26cdf38..e8b8938b0 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -4,25 +4,47 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
 
-from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
+from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
-async def get_router_impl(
-    api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
-):
-    from .routers import InferenceRouter, MemoryRouter
-    from .routing_table import RoutingTable
+async def get_routing_table_impl(
+    api: Api,
+    inner_impls: List[Tuple[str, Any]],
+    routing_table_config: RoutingTableConfig,
+    _deps,
+) -> Any:
+    from .routing_tables import (
+        MemoryBanksRoutingTable,
+        ModelsRoutingTable,
+        ShieldsRoutingTable,
+    )
 
-    api2routers = {
-        "memory": MemoryRouter,
-        "inference": InferenceRouter,
+    api_to_tables = {
+        "memory_banks": MemoryBanksRoutingTable,
+        "models": ModelsRoutingTable,
+        "shields": ShieldsRoutingTable,
     }
+    if api.value not in api_to_tables:
+        raise ValueError(f"API {api.value} not found in router map")
 
-    # initialize routing table with concrete provider impls
-    routing_table = RoutingTable(provider_routing_table)
-
-    impl = api2routers[api](routing_table)
+    impl = api_to_tables[api.value](inner_impls, routing_table_config)
+    await impl.initialize()
+    return impl
+
+
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+    from .routers import InferenceRouter, MemoryRouter, SafetyRouter
+
+    api_to_routers = {
+        "memory": MemoryRouter,
+        "inference": InferenceRouter,
+        "safety": SafetyRouter,
+    }
+    if api.value not in api_to_routers:
+        raise ValueError(f"API {api.value} not found in router map")
+
+    impl = api_to_routers[api.value](routing_table)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index fe70cd701..6d296d20e 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -4,17 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List
 
-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import RoutingTable
 
-from .routing_table import RoutingTable
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-
-from types import MethodType
-
-from termcolor import cprint
+from llama_stack.apis.safety import *  # noqa: F403
 
 
 class MemoryRouter(Memory):
@@ -24,22 +20,24 @@ class MemoryRouter(Memory):
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
-        self.api = Api.memory.value
         self.routing_table = routing_table
         self.bank_id_to_type = {}
 
     async def initialize(self) -> None:
-        await self.routing_table.initialize(self.api)
+        pass
 
    async def shutdown(self) -> None:
-        await self.routing_table.shutdown(self.api)
+        pass
 
     def get_provider_from_bank_id(self, bank_id: str) -> Any:
         bank_type = self.bank_id_to_type.get(bank_id)
         if not bank_type:
             raise ValueError(f"Could not find bank type for {bank_id}")
-        return self.routing_table.get_provider_impl(self.api, bank_type)
+        provider = self.routing_table.get_provider_impl(bank_type)
+        if not provider:
+            raise ValueError(f"Could not find provider for {bank_type}")
+        return provider
 
     async def create_memory_bank(
         self,
@@ -48,14 +46,15 @@ class MemoryRouter(Memory):
         url: Optional[URL] = None,
     ) -> MemoryBank:
         bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(
-            self.api, bank_type
+        bank = await self.routing_table.get_provider_impl(
+            bank_type
         ).create_memory_bank(name, config, url)
         self.bank_id_to_type[bank.bank_id] = bank_type
         return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
+        provider = self.get_provider_from_bank_id(bank_id)
+        return await provider.get_memory_bank(bank_id)
 
     async def insert_documents(
         self,
@@ -85,34 +84,31 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
-        self.api = Api.inference.value
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        await self.routing_table.initialize(self.api)
+        pass
 
     async def shutdown(self) -> None:
-        await self.routing_table.shutdown(self.api)
+        pass
 
     async def chat_completion(
         self,
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = [],
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(
-            self.api, model
-        ).chat_completion(
+        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,
@@ -128,7 +124,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(self.api, model).completion(
+        return await self.routing_table.get_provider_impl(model).completion(
             model=model,
             content=content,
             sampling_params=sampling_params,
@@ -141,7 +137,33 @@ class InferenceRouter(Inference):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(self.api, model).embeddings(
+        return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
         )
+
+
+class SafetyRouter(Safety):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_shield(
+        self,
+        shield_type: str,
+        messages: List[Message],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> RunShieldResponse:
+        return await self.routing_table.get_provider_impl(shield_type).run_shield(
+            shield_type=shield_type,
+            messages=messages,
+            params=params,
+        )
diff --git a/llama_stack/distribution/routers/routing_table.py b/llama_stack/distribution/routers/routing_table.py
deleted file mode 100644
index 46cf40155..000000000
--- a/llama_stack/distribution/routers/routing_table.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-
-from typing import Any, Dict, List
-
-from llama_stack.distribution.datatypes import (
-    Api,
-    GenericProviderConfig,
-    ProviderRoutingEntry,
-)
-from llama_stack.distribution.distribution import api_providers
-from llama_stack.distribution.utils.dynamic import instantiate_provider
-from termcolor import cprint
-
-
-class RoutingTable:
-    def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
-        self.provider_routing_table = provider_routing_table
-        # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <impl>, '70b': <impl>}}
-        self.api2routes = {}
-
-    async def initialize(self, api_str: str) -> None:
-        """Initialize the routing table with concrete provider impls"""
-        if api_str not in self.provider_routing_table:
-            raise ValueError(f"API {api_str} not found in routing table")
-
-        providers = api_providers()[Api(api_str)]
-        routing_list = self.provider_routing_table[api_str]
-
-        self.api2routes[api_str] = {}
-        for rt_entry in routing_list:
-            rt_key = rt_entry.routing_key
-            provider_id = rt_entry.provider_id
-            impl = await instantiate_provider(
-                providers[provider_id],
-                deps=[],
-                provider_config=GenericProviderConfig(
-                    provider_id=provider_id, config=rt_entry.config
-                ),
-            )
-            cprint(f"impl = {impl}", "red")
-            self.api2routes[api_str][rt_key] = impl
-
-        cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
-
-    async def shutdown(self, api_str: str) -> None:
-        """Shutdown the routing table"""
-        if api_str not in self.api2routes:
-            return
-
-        for impl in self.api2routes[api_str].values():
-            await impl.shutdown()
-
-    def get_provider_impl(self, api: str, routing_key: str) -> Any:
-        """Get the provider impl for a given api and routing key"""
-        return self.api2routes[api][routing_key]
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
new file mode 100644
index 000000000..a3f40b2b7
--- /dev/null
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Optional, Tuple
+
+from llama_models.sku_list import resolve_model
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+class CommonRoutingTableImpl(RoutingTable):
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        routing_table_config: RoutingTableConfig,
+    ) -> None:
+        self.providers = {k: v for k, v in inner_impls}
+        self.routing_keys = list(self.providers.keys())
+        self.routing_table_config = routing_table_config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        for p in self.providers.values():
+            await p.shutdown()
+
+    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
+        return self.providers.get(routing_key)
+
+    def get_routing_keys(self) -> List[str]:
+        return self.routing_keys
+
+    async def get_provider_config(
+        self, routing_key: str
+    ) -> Optional[GenericProviderConfig]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == routing_key:
+                return entry
+        return None
+
+
+class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+
+    async def list_models(self) -> List[ModelServingSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            model_id = entry.routing_key
+            specs.append(
+                ModelServingSpec(
+                    llama_model=resolve_model(model_id),
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == core_model_id:
+                return ModelServingSpec(
+                    llama_model=resolve_model(core_model_id),
+                    provider_config=entry,
+                )
+        return None
+
+
+class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
+
+    async def list_shields(self) -> List[ShieldSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            specs.append(
+                ShieldSpec(
+                    shield_type=entry.routing_key,
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == shield_type:
+                return ShieldSpec(
+                    shield_type=entry.routing_key,
+                    provider_config=entry,
+                )
+        return None
+
+
+class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
+
+    async def list_memory_banks(self) -> List[MemoryBankSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            specs.append(
+                MemoryBankSpec(
+                    bank_type=entry.routing_key,
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == bank_type:
+                return MemoryBankSpec(
+                    bank_type=entry.routing_key,
+                    provider_config=entry,
+                )
+        return None
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 76d467881..18433596f 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import importlib
 import inspect
 import json
 import signal
@@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated
 
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.distribution.distribution import api_endpoints, api_providers
-from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.utils.dynamic import (
-    instantiate_builtin_provider,
-    instantiate_provider,
-    instantiate_router,
+from llama_stack.distribution.distribution import (
+    api_endpoints,
+    api_providers,
+    builtin_automatically_routed_apis,
 )
+from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.utils.dynamic import instantiate_provider
 
 
 def is_async_iterator_type(typ):
@@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
     return "".join(word.capitalize() for word in snake_str.split("_"))
 
 
-async def resolve_impls_with_routing(
-    stack_run_config: StackRunConfig,
-) -> Dict[Api, Any]:
+async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
     - flatmaps, sorts and resolves the providers in dependency order
@@ -302,48 +299,80 @@
     """
     all_providers = api_providers()
     specs = {}
+    configs = {}
 
-    for api_str in stack_run_config.apis_to_serve:
+    for api_str, config in run_config.api_providers.items():
         api = Api(api_str)
+
+        # TODO: check that these APIs are not in the routing table part of the config
         providers = all_providers[api]
 
-        # check for routing table, we need to pass routing table to the router implementation
-        if api_str in stack_run_config.provider_routing_table:
-            specs[api] = RouterProviderSpec(
-                api=api,
-                module=f"llama_stack.distribution.routers",
-                api_dependencies=[],
-                routing_table=stack_run_config.provider_routing_table[api_str],
+        if config.provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
             )
-        else:
-            if api_str in stack_run_config.provider_map:
-                provider_map_entry = stack_run_config.provider_map[api_str]
-                provider_id = provider_map_entry.provider_id
-            else:
-                # not defined in config, will be a builtin provider, assign builtin provider id
-                provider_id = "builtin"
+        specs[api] = providers[config.provider_id]
+        configs[api] = config
 
-            if provider_id not in providers:
+    apis_to_serve = run_config.apis_to_serve or set(
+        [a.value for a in specs.keys()] + list(run_config.routing_tables.keys())
+    )
+    print("apis_to_serve", apis_to_serve)
+    for info in builtin_automatically_routed_apis():
+        source_api = info.routing_table_api
+
+        assert (
+            source_api not in specs
+        ), f"Routing table API {source_api} specified in wrong place?"
+        assert (
+            info.router_api not in specs
+        ), f"Auto-routed API {info.router_api} specified in wrong place?"
+
+        if info.router_api.value not in apis_to_serve:
+            continue
+
+        if source_api.value not in run_config.routing_tables:
+            raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
+
+        routing_table = run_config.routing_tables[source_api.value]
+
+        providers = all_providers[info.router_api]
+
+        inner_specs = []
+        for rt_entry in routing_table.entries:
+            if rt_entry.provider_id not in providers:
                 raise ValueError(
-                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                    f"Unknown provider `{rt_entry.provider_id}` is not available for API `{info.router_api}`"
                 )
-            specs[api] = providers[provider_id]
+            inner_specs.append(providers[rt_entry.provider_id])
+
+        specs[source_api] = RoutingTableProviderSpec(
+            api=source_api,
+            module="llama_stack.distribution.routers",
+            api_dependencies=[],
+            inner_specs=inner_specs,
+        )
+        configs[source_api] = routing_table
+
+        specs[info.router_api] = AutoRoutedProviderSpec(
+            api=info.router_api,
+            module="llama_stack.distribution.routers",
+            routing_table_api=source_api,
+            api_dependencies=[source_api],
+        )
+        configs[info.router_api] = {}
 
     sorted_specs = topological_sort(specs.values())
-
+    print(f"Resolved {len(sorted_specs)} providers in topological order")
+    for spec in sorted_specs:
+        print(f"  {spec.api}: {spec.provider_id}")
+    print("")
     impls = {}
     for spec in sorted_specs:
         api = spec.api
 
         deps = {api: impls[api] for api in spec.api_dependencies}
-        if api.value in stack_run_config.provider_map:
-            provider_config = stack_run_config.provider_map[api.value]
-            impl = await instantiate_provider(spec, deps, provider_config)
-        elif api.value in stack_run_config.provider_routing_table:
-            impl = await instantiate_router(
-                spec, api.value, stack_run_config.provider_routing_table
-            )
-        else:
-            impl = await instantiate_builtin_provider(spec, stack_run_config)
+        impl = await instantiate_provider(spec, deps, configs[api])
+
         impls[api] = impl
 
     return impls, specs
@@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
 
     app = FastAPI()
 
-    # impls, specs = asyncio.run(resolve_impls(config.provider_map))
     impls, specs = asyncio.run(resolve_impls_with_routing(config))
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
 
     all_endpoints = api_endpoints()
 
-    apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    if config.apis_to_serve:
+        apis_to_serve = set(config.apis_to_serve)
+        for inf in builtin_automatically_routed_apis():
+            if inf.router_api.value in apis_to_serve:
+                apis_to_serve.add(inf.routing_table_api.value)
+    else:
+        apis_to_serve = set(impls.keys())
+
     for api_str in apis_to_serve:
         api = Api(api_str)
+
         endpoints = all_endpoints[api]
         impl = impls[api]
 
@@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
                 create_dynamic_typed_route(
                     impl_method,
                     endpoint.method,
-                    provider_spec.provider_data_validator,
+                    (
+                        provider_spec.provider_data_validator
+                        if not isinstance(provider_spec, RoutingTableProviderSpec)
+                        else None
+                    ),
                 )
             )
diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index 85254b246..6d9c57dfd 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
     return getattr(module, class_name)
 
 
-async def instantiate_router(
-    provider_spec: RouterProviderSpec,
-    api: str,
-    provider_routing_table: Dict[str, Any],
-):
-    module = importlib.import_module(provider_spec.module)
-
-    fn = getattr(module, "get_router_impl")
-    impl = await fn(api, provider_routing_table)
-    impl.__provider_spec__ = provider_spec
-    return impl
-
-
-async def instantiate_builtin_provider(
-    provider_spec: BuiltinProviderSpec,
-    run_config: StackRunConfig,
-):
-    print("!!! instantiate_builtin_provider")
-    module = importlib.import_module(provider_spec.module)
-    fn = getattr(module, "get_builtin_impl")
-    impl = await fn(run_config)
-    impl.__provider_spec__ = provider_spec
-    return impl
-
-
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider_spec: ProviderSpec,
     deps: Dict[str, Any],
-    provider_config: ProviderMapEntry,
+    provider_config: Union[GenericProviderConfig, RoutingTableConfig],
 ):
     module = importlib.import_module(provider_spec.module)
 
@@ -60,6 +35,29 @@ async def instantiate_provider(
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
         args = [config, deps]
+    elif isinstance(provider_spec, AutoRoutedProviderSpec):
+        method = "get_auto_router_impl"
+
+        config = None
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+    elif isinstance(provider_spec, RoutingTableProviderSpec):
+        method = "get_routing_table_impl"
+
+        assert isinstance(provider_config, RoutingTableConfig)
+        routing_table = provider_config
+
+        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+        inner_impls = []
+        for routing_entry in routing_table.entries:
+            impl = await instantiate_provider(
+                inner_specs[routing_entry.provider_id],
+                deps,
+                routing_entry,
+            )
+            inner_impls.append((routing_entry.routing_key, impl))
+
+        config = None
+        args = [provider_spec.api, inner_impls, routing_table, deps]
     else:
         method = "get_provider_impl"
diff --git a/llama_stack/providers/impls/builtin/models/__init__.py b/llama_stack/providers/impls/builtin/models/__init__.py
deleted file mode 100644
index cd969917e..000000000
--- a/llama_stack/providers/impls/builtin/models/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict
-
-from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
-
-from .config import BuiltinImplConfig  # noqa
-
-
-async def get_builtin_impl(config: StackRunConfig):
-    from .models import BuiltinModelsImpl
-
-    assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
-
-    impl = BuiltinModelsImpl(config)
-    await impl.initialize()
-    return impl
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
deleted file mode 100644
index 74d1299a4..000000000
--- a/llama_stack/providers/impls/builtin/models/models.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-import asyncio
-
-from typing import AsyncIterator, Union
-
-from llama_models.llama3.api.datatypes import StopReason
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.distribution import Api, api_providers
-
-from llama_stack.apis.models import *  # noqa: F403
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.datatypes import CoreModelId, Model
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.datatypes import *  # noqa: F403
-from termcolor import cprint
-
-
-class BuiltinModelsImpl(Models):
-    def __init__(
-        self,
-        config: StackRunConfig,
-    ) -> None:
-        self.run_config = config
-        self.models = {}
-
-        # check against inference & safety api
-        apis_with_models = [Api.inference, Api.safety]
-
-        all_providers = api_providers()
-
-        for api in apis_with_models:
-
-            # check against provider_map (simple case single model)
-            if api.value in config.provider_map:
-                providers_for_api = all_providers[api]
-                provider_spec = config.provider_map[api.value]
-                core_model_id = provider_spec.config
-
-                # get supported model ids from the provider
-                supported_model_ids = self.get_supported_model_ids(
-                    api.value, provider_spec, providers_for_api
-                )
-                for model_id in supported_model_ids:
-                    self.models[model_id] = ModelServingSpec(
-                        llama_model=resolve_model(model_id),
-                        provider_config=provider_spec,
-                        api=api.value,
-                    )
-
-            # check against provider_routing_table (router with multiple models)
-            # with routing table, we use the routing_key as the supported models
-            if api.value in config.provider_routing_table:
-                routing_table = config.provider_routing_table[api.value]
-                for rt_entry in routing_table:
-                    model_id = rt_entry.routing_key
-                    self.models[model_id] = ModelServingSpec(
-                        llama_model=resolve_model(model_id),
-                        provider_config=GenericProviderConfig(
-                            provider_id=rt_entry.provider_id,
-                            config=rt_entry.config,
-                        ),
-                        api=api.value,
-                    )
-
-        print("BuiltinModelsImpl models", self.models)
-
-    def get_supported_model_ids(
-        self,
-        api: str,
-        provider_spec: GenericProviderConfig,
-        providers_for_api: Dict[str, ProviderSpec],
-    ) -> List[str]:
-        serving_models_list = []
-        if api == Api.inference.value:
-            provider_id = provider_spec.provider_id
-            if provider_id == "meta-reference":
-                serving_models_list.append(provider_spec.config["model"])
-            if provider_id in {
-                remote_provider_id("ollama"),
-                remote_provider_id("fireworks"),
-                remote_provider_id("together"),
-            }:
-                adapter_supported_models = providers_for_api[
-                    provider_id
-                ].adapter.supported_model_ids
-                serving_models_list.extend(adapter_supported_models)
-        elif api == Api.safety.value:
-            if provider_spec.config and "llama_guard_shield" in provider_spec.config:
-                llama_guard_shield = provider_spec.config["llama_guard_shield"]
-                serving_models_list.append(llama_guard_shield["model"])
-            if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
-                prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
-                serving_models_list.append(prompt_guard_shield["model"])
-        else:
-            raise NotImplementedError(f"Unsupported api {api} for builtin models")
-
-        return serving_models_list
-
-    async def initialize(self) -> None:
-        pass
-
-    async def list_models(self) -> ModelsListResponse:
-        return ModelsListResponse(models_list=list(self.models.values()))
-
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
-        if core_model_id in self.models:
-            return ModelsGetResponse(core_model_spec=self.models[core_model_id])
-        print(f"Cannot find {core_model_id} in model registry")
-        return ModelsGetResponse()
diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py
index 8e3d3ed3c..d9b397571 100644
--- a/llama_stack/providers/impls/meta_reference/inference/config.py
+++ b/llama_stack/providers/impls/meta_reference/inference/config.py
@@ -6,17 +6,14 @@
 
 from typing import Optional
 
-from llama_models.datatypes import ModelFamily
-
-from llama_models.schema_utils import json_schema_type
+from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models, resolve_model
 
+from llama_stack.apis.inference import *  # noqa: F401, F403
+
 from pydantic import BaseModel, Field, field_validator
 
-from llama_stack.apis.inference import QuantizationConfig
-
-@json_schema_type
 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
         default="Meta-Llama3.1-8B-Instruct",
@@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
             m.descriptor()
             for m in all_registered_models()
             if m.model_family == ModelFamily.llama3_1
+            or m.core_model_id == CoreModelId.llama_guard_3_8b
         ]
         if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index bf739eefa..10b3d6ccc 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
             adapter_id="ollama",
             pip_packages=["ollama"],
             module="llama_stack.providers.adapters.inference.ollama",
-            supported_model_ids=[
-                "Meta-Llama3.1-8B-Instruct",
-                "Meta-Llama3.1-70B-Instruct",
-            ],
         ),
     ),
     remote_provider_spec(
@@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.adapters.inference.fireworks",
             config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
-            supported_model_ids=[
-                "Meta-Llama3.1-8B-Instruct",
-                "Meta-Llama3.1-70B-Instruct",
-                "Meta-Llama3.1-405B-Instruct",
-            ],
         ),
     ),
     remote_provider_spec(
@@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.adapters.inference.together",
             config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
             header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
-            supported_model_ids=[
-                "Meta-Llama3.1-8B-Instruct",
-                "Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct", - ], ), ), ] diff --git a/llama_stack/providers/registry/models.py b/llama_stack/providers/registry/models.py deleted file mode 100644 index 47ec948c4..000000000 --- a/llama_stack/providers/registry/models.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List - -from llama_stack.distribution.datatypes import * # noqa: F403 - - -def available_providers() -> List[ProviderSpec]: - return [ - BuiltinProviderSpec( - api=Api.models, - provider_id="builtin", - pip_packages=[], - module="llama_stack.providers.impls.builtin.models", - config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig", - api_dependencies=[], - ) - ]