From 85d927adde0fe4f9bdfe519cc0cf19547f244007 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 21 Sep 2024 13:44:33 -0700 Subject: [PATCH] skeleton unified routing table, api routers --- llama_stack/apis/inference/client.py | 12 +-- llama_stack/distribution/datatypes.py | 51 ++++++++--- llama_stack/distribution/routers/__init__.py | 27 ++++++ llama_stack/distribution/routers/routers.py | 84 +++++++++++++++++ .../distribution/routers/routing_table.py | 16 ++++ llama_stack/distribution/server/server.py | 50 +++++----- llama_stack/distribution/utils/dynamic.py | 29 +++--- llama_stack/examples/router-table-run.yaml | 59 +----------- llama_stack/providers/routers/__init__.py | 5 - .../providers/routers/memory/__init__.py | 17 ---- .../providers/routers/memory/memory.py | 91 ------------------- 11 files changed, 210 insertions(+), 231 deletions(-) create mode 100644 llama_stack/distribution/routers/__init__.py create mode 100644 llama_stack/distribution/routers/routers.py create mode 100644 llama_stack/distribution/routers/routing_table.py delete mode 100644 llama_stack/providers/routers/__init__.py delete mode 100644 llama_stack/providers/routers/memory/__init__.py delete mode 100644 llama_stack/providers/routers/memory/memory.py diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index 8c75c6893..5c616a5c0 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator import fire import httpx -from pydantic import BaseModel -from termcolor import cprint from llama_stack.distribution.datatypes import RemoteProviderConfig +from pydantic import BaseModel +from termcolor import cprint from .event_logger import EventLogger @@ -104,11 +104,9 @@ async def run_main(host: str, port: int, stream: bool): ) cprint(f"User>{message.content}", "green") iterator = client.chat_completion( - ChatCompletionRequest( - model="Meta-Llama3.1-8B-Instruct", - messages=[message], - stream=stream, - ) + model="Meta-Llama3.1-8B-Instruct", + messages=[message], + stream=stream, ) async for log in EventLogger().log(iterator): log.print() diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a230dacf7..24b8443bf 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -43,6 +43,16 @@ class ProviderSpec(BaseModel): ) +class GenericProviderConfig(BaseModel): + provider_id: str + config: Dict[str, Any] + + +@json_schema_type +class ProviderRoutingEntry(GenericProviderConfig): + routing_key: str + + @json_schema_type class RouterProviderSpec(ProviderSpec): provider_id: str = "router" @@ -50,14 +60,20 @@ class RouterProviderSpec(ProviderSpec): docker_image: Optional[str] = None - inner_specs: List[ProviderSpec] + routing_table: List[ProviderRoutingEntry] = Field( + default_factory=list, + description="Routing table entries corresponding to the API", + ) module: str = Field( ..., description=""" -Fully-qualified name of the module to import. The module is expected to have: + Fully-qualified name of the module to import. The module is expected to have: - - `get_router_impl(config, provider_specs, deps)`: returns the router implementation -""", + - `get_router_impl(config, provider_specs, deps)`: returns the router implementation + """, + ) + provider_data_validator: Optional[str] = Field( + default=None, ) @property @@ -65,11 +81,6 @@ Fully-qualified name of the module to import. The module is expected to have: raise AssertionError("Should not be called on RouterProviderSpec") -class GenericProviderConfig(BaseModel): - provider_id: str - config: Dict[str, Any] - - @json_schema_type class AdapterSpec(BaseModel): adapter_id: str = Field( @@ -204,12 +215,7 @@ in the runtime configuration to help route to the correct provider.""", ) -@json_schema_type -class ProviderRoutingEntry(GenericProviderConfig): - routing_key: str - - -ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]] +ProviderMapEntry = GenericProviderConfig @json_schema_type @@ -248,6 +254,21 @@ As examples: The key may support wild-cards alsothe routing_key to route to the correct provider.""", ) + provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field( + description=""" + API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple. + + E.g. The following is a ProviderRoutingEntry for inference API: + - routing_key: Meta-Llama3.1-8B-Instruct + provider_id: meta-reference + config: + model: Meta-Llama3.1-8B-Instruct + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + """ + ) @json_schema_type diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py new file mode 100644 index 000000000..707797aab --- /dev/null +++ b/llama_stack/distribution/routers/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +from llama_stack.distribution.datatypes import Api + + +async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]): + from .routers import InferenceRouter, MemoryRouter + from .routing_table import RoutingTable + + api2routers = { + "memory": MemoryRouter, + "inference": InferenceRouter, + } + + routing_table = RoutingTable(provider_routing_table) + routing_table.print() + + impl = api2routers[api](routing_table) + # impl = Router(api, provider_routing_table) + await impl.initialize() + return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py new file mode 100644 index 000000000..836db1b5f --- /dev/null +++ b/llama_stack/distribution/routers/routers.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +from llama_stack.distribution.datatypes import Api + +from .routing_table import RoutingTable +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 + + +class MemoryRouter(Memory): + """Routes to an provider based on the memory bank type""" + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def create_memory_bank( + self, + name: str, + config: MemoryBankConfig, + url: Optional[URL] = None, + ) -> MemoryBank: + print("MemoryRouter: create_memory_bank") + + async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: + print("MemoryRouter: get_memory_bank") + + async def insert_documents( + self, + bank_id: str, + documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, + ) -> None: + print("MemoryRouter: insert_documents") + + async def query_documents( + self, + bank_id: str, + query: InterleavedTextMedia, + params: Optional[Dict[str, Any]] = None, + ) -> QueryDocumentsResponse: + print("query_documents") + + +class InferenceRouter(Inference): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def chat_completion( + self, + model: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + # zero-shot tool definitions as input to the model + tools: Optional[List[ToolDefinition]] = list, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: + print("Inference Router: chat_completion") diff --git a/llama_stack/distribution/routers/routing_table.py b/llama_stack/distribution/routers/routing_table.py new file mode 100644 index 000000000..ccc3f5b7c --- /dev/null +++ b/llama_stack/distribution/routers/routing_table.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any, Dict + + +class RoutingTable: + def __init__(self, provider_routing_table: Dict[str, Any]): + self.provider_routing_table = provider_routing_table + + def print(self): + print(f"ROUTING TABLE {self.provider_routing_table}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 94df176fc..468501980 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -50,7 +50,10 @@ from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import api_endpoints, api_providers from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.utils.dynamic import instantiate_provider +from llama_stack.distribution.utils.dynamic import ( + instantiate_provider, + instantiate_router, +) def is_async_iterator_type(typ): @@ -288,8 +291,8 @@ def snake_to_camel(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) -async def resolve_impls( - provider_map: Dict[str, ProviderMapEntry], +async def resolve_impls_with_routing( + stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: """ Does two things: @@ -297,33 +300,28 @@ async def resolve_impls( - for each API, produces either a (local, passthrough or router) implementation """ all_providers = api_providers() - specs = {} - for api_str, item in provider_map.items(): + + for api_str in stack_run_config.apis_to_serve: api = Api(api_str) providers = all_providers[api] - if isinstance(item, GenericProviderConfig): - if item.provider_id not in providers: + # check for regular providers without routing + if api_str in stack_run_config.provider_map: + provider_map_entry = stack_run_config.provider_map[api_str] + if provider_map_entry.provider_id not in providers: raise ValueError( f"Unknown provider `{provider_id}` is not available for API `{api}`" ) - specs[api] = providers[item.provider_id] - else: - assert isinstance(item, list) - inner_specs = [] - for rt_entry in item: - if rt_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" - ) - inner_specs.append(providers[rt_entry.provider_id]) + specs[api] = providers[provider_map_entry.provider_id] + # check for routing table, we need to pass routing table to the router implementation + if api_str in stack_run_config.provider_routing_table: specs[api] = RouterProviderSpec( api=api, - module=f"llama_stack.providers.routers.{api.value.lower()}", + module=f"llama_stack.distribution.routers", api_dependencies=[], - inner_specs=inner_specs, + routing_table=stack_run_config.provider_routing_table[api_str], ) sorted_specs = topological_sort(specs.values()) @@ -331,9 +329,16 @@ async def resolve_impls( impls = {} for spec in sorted_specs: api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} - impl = await instantiate_provider(spec, deps, provider_map[api.value]) + if api.value in stack_run_config.provider_map: + provider_config = stack_run_config.provider_map[api.value] + impl = await instantiate_provider(spec, deps, provider_config) + elif api.value in stack_run_config.provider_routing_table: + impl = await instantiate_router( + spec, api.value, stack_run_config.provider_routing_table + ) + else: + raise ValueError(f"Cannot find provider_config for Api {api.value}") impls[api] = impl return impls, specs @@ -345,7 +350,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): app = FastAPI() - impls, specs = asyncio.run(resolve_impls(config.provider_map)) + # impls, specs = asyncio.run(resolve_impls(config.provider_map)) + impls, specs = asyncio.run(resolve_impls_with_routing(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 002a738ae..f807b096d 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -16,6 +16,19 @@ def instantiate_class_type(fully_qualified_name): return getattr(module, class_name) +async def instantiate_router( + provider_spec: RouterProviderSpec, + api: str, + provider_routing_table: Dict[str, Any], +): + module = importlib.import_module(provider_spec.module) + + fn = getattr(module, "get_router_impl") + impl = await fn(api, provider_routing_table) + impl.__provider_spec__ = provider_spec + return impl + + # returns a class implementing the protocol corresponding to the Api async def instantiate_provider( provider_spec: ProviderSpec, @@ -35,22 +48,6 @@ async def instantiate_provider( config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider_config.config) args = [config, deps] - elif isinstance(provider_spec, RouterProviderSpec): - method = "get_router_impl" - - assert isinstance(provider_config, list) - inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} - inner_impls = [] - for routing_entry in provider_config: - impl = await instantiate_provider( - inner_specs[routing_entry.provider_id], - deps, - routing_entry, - ) - inner_impls.append((routing_entry.routing_key, impl)) - - config = None - args = [inner_impls, deps] else: method = "get_provider_impl" diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml index 9fbc394c1..74a82bebc 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -5,7 +5,7 @@ conda_env: local apis_to_serve: - inference # - memory -# - telemetry +- telemetry provider_map: telemetry: provider_id: meta-reference @@ -36,60 +36,3 @@ provider_routing_table: - routing_key: vector provider_id: meta-reference config: {} - - - -# safety: -# provider_id: meta-reference -# config: -# llama_guard_shield: -# model: Llama-Guard-3-8B -# excluded_categories: [] -# disable_input_check: false -# disable_output_check: false -# prompt_guard_shield: -# model: Prompt-Guard-86M -# telemetry: -# provider_id: meta-reference -# config: {} -# agents: -# provider_id: meta-reference -# config: {} -# memory: -# provider_id: meta-reference -# config: {} -# models: -# provider_id: builtin -# config: -# models_config: -# - core_model_id: Meta-Llama3.1-8B-Instruct -# provider_id: meta-reference -# api: inference -# config: -# model: Meta-Llama3.1-8B-Instruct -# quantization: null -# torch_seed: null -# max_seq_len: 4096 -# max_batch_size: 1 -# - core_model_id: Meta-Llama3.1-8B -# provider_id: meta-reference -# api: inference -# config: -# model: Meta-Llama3.1-8B -# quantization: null -# torch_seed: null -# max_seq_len: 4096 -# max_batch_size: 1 -# - core_model_id: Llama-Guard-3-8B -# provider_id: meta-reference -# api: safety -# config: -# model: Llama-Guard-3-8B -# excluded_categories: [] -# disable_input_check: false -# disable_output_check: false -# - core_model_id: Prompt-Guard-86M -# provider_id: meta-reference -# api: safety -# config: -# model: Prompt-Guard-86M diff --git a/llama_stack/providers/routers/__init__.py b/llama_stack/providers/routers/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/routers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/routers/memory/__init__.py b/llama_stack/providers/routers/memory/__init__.py deleted file mode 100644 index d4dbbb1d4..000000000 --- a/llama_stack/providers/routers/memory/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, List, Tuple - -from llama_stack.distribution.datatypes import Api - - -async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]): - from .memory import MemoryRouterImpl - - impl = MemoryRouterImpl(inner_impls, deps) - await impl.initialize() - return impl diff --git a/llama_stack/providers/routers/memory/memory.py b/llama_stack/providers/routers/memory/memory.py deleted file mode 100644 index b96cde626..000000000 --- a/llama_stack/providers/routers/memory/memory.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, List, Tuple - -from llama_stack.distribution.datatypes import Api -from llama_stack.apis.memory import * # noqa: F403 - - -class MemoryRouterImpl(Memory): - """Routes to an provider based on the memory bank type""" - - def __init__( - self, - inner_impls: List[Tuple[str, Any]], - deps: List[Api], - ) -> None: - self.deps = deps - - bank_types = [v.value for v in MemoryBankType] - - self.providers = {} - for routing_key, provider_impl in inner_impls: - if routing_key not in bank_types: - raise ValueError( - f"Unknown routing key `{routing_key}` for memory bank type" - ) - self.providers[routing_key] = provider_impl - - self.bank_id_to_type = {} - - async def initialize(self) -> None: - pass - - async def shutdown(self) -> None: - for p in self.providers.values(): - await p.shutdown() - - def get_provider(self, bank_type): - if bank_type not in self.providers: - raise ValueError(f"Memory bank type {bank_type} not supported") - - return self.providers[bank_type] - - async def create_memory_bank( - self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - provider = self.get_provider(config.type) - bank = await provider.create_memory_bank(name, config, url) - self.bank_id_to_type[bank.bank_id] = config.type - return bank - - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.get_memory_bank(bank_id) - - async def insert_documents( - self, - bank_id: str, - documents: List[MemoryBankDocument], - ttl_seconds: Optional[int] = None, - ) -> None: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.insert_documents(bank_id, documents, ttl_seconds) - - async def query_documents( - self, - bank_id: str, - query: InterleavedTextMedia, - params: Optional[Dict[str, Any]] = None, - ) -> QueryDocumentsResponse: - bank_type = self.bank_id_to_type.get(bank_id) - if not bank_type: - raise ValueError(f"Could not find bank type for {bank_id}") - - provider = self.get_provider(bank_type) - return await provider.query_documents(bank_id, query, params)