From 951cc9d7b7f18955b214a901471dd9a18c3a2277 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 21 Sep 2024 14:26:48 -0700 Subject: [PATCH] router table registration works --- llama_stack/distribution/routers/__init__.py | 9 ++-- llama_stack/distribution/routers/routers.py | 25 +++++++-- .../distribution/routers/routing_table.py | 52 +++++++++++++++++-- llama_stack/examples/router-table-run.yaml | 32 +++++++----- 4 files changed, 91 insertions(+), 27 deletions(-) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 707797aab..9f26cdf38 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -6,10 +6,12 @@ from typing import Any, Dict, List, Tuple -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry -async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]): +async def get_router_impl( + api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]] +): from .routers import InferenceRouter, MemoryRouter from .routing_table import RoutingTable @@ -18,10 +20,9 @@ async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]): "inference": InferenceRouter, } + # initialize routing table with concrete provider impls routing_table = RoutingTable(provider_routing_table) - routing_table.print() impl = api2routers[api](routing_table) - # impl = Router(api, provider_routing_table) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 836db1b5f..7d62dc76c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, List, Tuple +from typing import Any, AsyncGenerator, Dict, List, Tuple from llama_stack.distribution.datatypes import Api @@ -57,17 +57,20 @@ class MemoryRouter(Memory): class InferenceRouter(Inference): + """Routes to a provider based on the model""" + def __init__( self, routing_table: RoutingTable, ) -> None: + self.api = Api.inference.value self.routing_table = routing_table async def initialize(self) -> None: - pass + await self.routing_table.initialize(self.api) async def shutdown(self) -> None: - pass + await self.routing_table.shutdown(self.api) async def chat_completion( self, @@ -80,5 +83,17 @@ class InferenceRouter(Inference): tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: - print("Inference Router: chat_completion") + ) -> AsyncGenerator: + provider_impl = self.routing_table.get_provider_impl(self.api, model) + print("InferenceRouter: chat_completion", provider_impl) + async for chunk in provider_impl.chat_completion( + model=model, + messages=messages, + sampling_params=sampling_params, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ): + yield chunk diff --git a/llama_stack/distribution/routers/routing_table.py b/llama_stack/distribution/routers/routing_table.py index ccc3f5b7c..46cf40155 100644 --- a/llama_stack/distribution/routers/routing_table.py +++ b/llama_stack/distribution/routers/routing_table.py @@ -5,12 +5,56 @@ # the root directory of this source tree. 
-from typing import Any, Dict +from typing import Any, Dict, List + +from llama_stack.distribution.datatypes import ( + Api, + GenericProviderConfig, + ProviderRoutingEntry, +) +from llama_stack.distribution.distribution import api_providers +from llama_stack.distribution.utils.dynamic import instantiate_provider +from termcolor import cprint class RoutingTable: - def __init__(self, provider_routing_table: Dict[str, Any]): + def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]): self.provider_routing_table = provider_routing_table + # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <impl>, '70b': <impl>}} + self.api2routes = {} - def print(self): - print(f"ROUTING TABLE {self.provider_routing_table}") + async def initialize(self, api_str: str) -> None: + """Initialize the routing table with concrete provider impls""" + if api_str not in self.provider_routing_table: + raise ValueError(f"API {api_str} not found in routing table") + + providers = api_providers()[Api(api_str)] + routing_list = self.provider_routing_table[api_str] + + self.api2routes[api_str] = {} + for rt_entry in routing_list: + rt_key = rt_entry.routing_key + provider_id = rt_entry.provider_id + impl = await instantiate_provider( + providers[provider_id], + deps=[], + provider_config=GenericProviderConfig( + provider_id=provider_id, config=rt_entry.config + ), + ) + cprint(f"impl = {impl}", "red") + self.api2routes[api_str][rt_key] = impl + + cprint(f"> Initialized implementations for {api_str} in routing table", "blue") + + async def shutdown(self, api_str: str) -> None: + """Shutdown the routing table""" + if api_str not in self.api2routes: + return + + for impl in self.api2routes[api_str].values(): + await impl.shutdown() + + def get_provider_impl(self, api: str, routing_key: str) -> Any: + """Get the provider impl for a given api and routing key""" + return self.api2routes[api][routing_key] diff --git a/llama_stack/examples/router-table-run.yaml 
b/llama_stack/examples/router-table-run.yaml index 74a82bebc..4264b8984 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -21,18 +21,22 @@ provider_routing_table: max_seq_len: 4096 max_batch_size: 1 - routing_key: Meta-Llama3.1-8B - provider_id: remote::ollama - config: - url: http:ollama-url-1.com - memory: - - routing_key: keyvalue - provider_id: remote::pgvector - config: - host: localhost - port: 5432 - db: vectordb - user: vectoruser - password: xxxx - - routing_key: vector provider_id: meta-reference - config: {} + config: + model: Meta-Llama3.1-8B + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + # memory: + # - routing_key: keyvalue + # provider_id: remote::pgvector + # config: + # host: localhost + # port: 5432 + # db: vectordb + # user: vectoruser + # password: xxxx + # - routing_key: vector + # provider_id: meta-reference + # config: {}