router table registration works

This commit is contained in:
Xi Yan 2024-09-21 14:26:48 -07:00
parent 85d927adde
commit 951cc9d7b7
4 changed files with 91 additions and 27 deletions

View file

@ -6,10 +6,12 @@
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]): async def get_router_impl(
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
):
from .routers import InferenceRouter, MemoryRouter from .routers import InferenceRouter, MemoryRouter
from .routing_table import RoutingTable from .routing_table import RoutingTable
@ -18,10 +20,9 @@ async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
"inference": InferenceRouter, "inference": InferenceRouter,
} }
# initialize routing table with concrete provider impls
routing_table = RoutingTable(provider_routing_table) routing_table = RoutingTable(provider_routing_table)
routing_table.print()
impl = api2routers[api](routing_table) impl = api2routers[api](routing_table)
# impl = Router(api, provider_routing_table)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Tuple from typing import Any, AsyncGenerator, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
@ -57,17 +57,20 @@ class MemoryRouter(Memory):
class InferenceRouter(Inference): class InferenceRouter(Inference):
"""Routes to an provider based on the model"""
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
self.api = Api.inference.value
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
pass await self.routing_table.initialize(self.api)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass await self.routing_table.shutdown(self.api)
async def chat_completion( async def chat_completion(
self, self,
@ -80,5 +83,17 @@ class InferenceRouter(Inference):
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ) -> AsyncGenerator:
print("Inference Router: chat_completion") provider_impl = self.routing_table.get_provider_impl(self.api, model)
print("InferenceRouter: chat_completion", provider_impl)
async for chunk in provider_impl.chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
):
yield chunk

View file

@ -5,12 +5,56 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any, Dict, List
from llama_stack.distribution.datatypes import (
Api,
GenericProviderConfig,
ProviderRoutingEntry,
)
from llama_stack.distribution.distribution import api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from termcolor import cprint
class RoutingTable: class RoutingTable:
def __init__(self, provider_routing_table: Dict[str, Any]): def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
self.provider_routing_table = provider_routing_table self.provider_routing_table = provider_routing_table
# map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
self.api2routes = {}
def print(self): async def initialize(self, api_str: str) -> None:
print(f"ROUTING TABLE {self.provider_routing_table}") """Initialize the routing table with concrete provider impls"""
if api_str not in self.provider_routing_table:
raise ValueError(f"API {api_str} not found in routing table")
providers = api_providers()[Api(api_str)]
routing_list = self.provider_routing_table[api_str]
self.api2routes[api_str] = {}
for rt_entry in routing_list:
rt_key = rt_entry.routing_key
provider_id = rt_entry.provider_id
impl = await instantiate_provider(
providers[provider_id],
deps=[],
provider_config=GenericProviderConfig(
provider_id=provider_id, config=rt_entry.config
),
)
cprint(f"impl = {impl}", "red")
self.api2routes[api_str][rt_key] = impl
cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
async def shutdown(self, api_str: str) -> None:
"""Shutdown the routing table"""
if api_str not in self.api2routes:
return
for impl in self.api2routes[api_str].values():
await impl.shutdown()
def get_provider_impl(self, api: str, routing_key: str) -> Any:
"""Get the provider impl for a given api and routing key"""
return self.api2routes[api][routing_key]

View file

@ -21,18 +21,22 @@ provider_routing_table:
max_seq_len: 4096 max_seq_len: 4096
max_batch_size: 1 max_batch_size: 1
- routing_key: Meta-Llama3.1-8B - routing_key: Meta-Llama3.1-8B
provider_id: remote::ollama
config:
url: http:ollama-url-1.com
memory:
- routing_key: keyvalue
provider_id: remote::pgvector
config:
host: localhost
port: 5432
db: vectordb
user: vectoruser
password: xxxx
- routing_key: vector
provider_id: meta-reference provider_id: meta-reference
config: {} config:
model: Meta-Llama3.1-8B
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
# memory:
# - routing_key: keyvalue
# provider_id: remote::pgvector
# config:
# host: localhost
# port: 5432
# db: vectordb
# user: vectoruser
# password: xxxx
# - routing_key: vector
# provider_id: meta-reference
# config: {}