Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
router table registration works

parent 85d927adde
commit 951cc9d7b7

4 changed files with 91 additions and 27 deletions
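The change replaces ad-hoc provider dicts with a routing table that maps each API and routing key (typically a model name) to a concrete provider implementation. The first diff updates the router factory, the second makes InferenceRouter delegate through the table, the third implements the table itself, and the fourth adjusts the example run config. As a minimal illustration of the lookup structure the table builds (values here are placeholder strings, not real provider objects):

# Shape of the internal map built by RoutingTable (see the third diff below):
# {api: {routing_key: provider_impl}}
api2routes = {
    "inference": {
        "Meta-Llama3.1-8B": "<meta-reference impl>",  # placeholder; really a provider object
    }
}
impl = api2routes["inference"]["Meta-Llama3.1-8B"]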
@@ -6,10 +6,12 @@
 from typing import Any, Dict, List, Tuple

-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry


-async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
+async def get_router_impl(
+    api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
+):
     from .routers import InferenceRouter, MemoryRouter
     from .routing_table import RoutingTable

@@ -18,10 +20,9 @@ async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
         "inference": InferenceRouter,
     }

+    # initialize routing table with concrete provider impls
     routing_table = RoutingTable(provider_routing_table)
-    routing_table.print()

     impl = api2routers[api](routing_table)
-    # impl = Router(api, provider_routing_table)
     await impl.initialize()
     return impl
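The factory now takes a typed routing table (`Dict[str, List[ProviderRoutingEntry]]`), builds a `RoutingTable`, wraps it in the router class registered for the requested API, and initializes it. A hedged sketch of a call site; the `ProviderRoutingEntry` constructor and field names are assumptions inferred from the run config at the end of this commit:

# Hypothetical call site (not part of the commit); assumes ProviderRoutingEntry
# carries routing_key, provider_id, and config, mirroring the YAML below.
async def build_inference_router():
    entry = ProviderRoutingEntry(
        routing_key="Meta-Llama3.1-8B",
        provider_id="meta-reference",
        config={"model": "Meta-Llama3.1-8B", "max_seq_len": 4096, "max_batch_size": 1},
    )
    return await get_router_impl("inference", {"inference": [entry]})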
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Tuple

 from llama_stack.distribution.datatypes import Api

@@ -57,17 +57,20 @@ class MemoryRouter(Memory):


 class InferenceRouter(Inference):
+    """Routes to an provider based on the model"""
+
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
+        self.api = Api.inference.value
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        pass
+        await self.routing_table.initialize(self.api)

     async def shutdown(self) -> None:
-        pass
+        await self.routing_table.shutdown(self.api)

     async def chat_completion(
         self,
@@ -80,5 +83,17 @@ class InferenceRouter(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        print("Inference Router: chat_completion")
+    ) -> AsyncGenerator:
+        provider_impl = self.routing_table.get_provider_impl(self.api, model)
+        print("InferenceRouter: chat_completion", provider_impl)
+        async for chunk in provider_impl.chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            yield chunk
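Because `chat_completion` now contains a `yield`, Python treats it as an async generator regardless of the `stream` flag, which is why the return annotation changes from `Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]` to `AsyncGenerator`. Callers consume it with `async for`. A minimal usage sketch, assuming a router built as above; the message shape is an illustrative assumption, not the real llama-stack message type:

# Sketch only: the dict-shaped message is an assumption for illustration.
async def run(router) -> None:
    async for chunk in router.chat_completion(
        model="Meta-Llama3.1-8B",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    ):
        print(chunk)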
@@ -5,12 +5,56 @@
 # the root directory of this source tree.


-from typing import Any, Dict
+from typing import Any, Dict, List

+from llama_stack.distribution.datatypes import (
+    Api,
+    GenericProviderConfig,
+    ProviderRoutingEntry,
+)
+from llama_stack.distribution.distribution import api_providers
+from llama_stack.distribution.utils.dynamic import instantiate_provider
+from termcolor import cprint
+

 class RoutingTable:
-    def __init__(self, provider_routing_table: Dict[str, Any]):
+    def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
         self.provider_routing_table = provider_routing_table
+        # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
+        self.api2routes = {}

-    def print(self):
-        print(f"ROUTING TABLE {self.provider_routing_table}")
+    async def initialize(self, api_str: str) -> None:
+        """Initialize the routing table with concrete provider impls"""
+        if api_str not in self.provider_routing_table:
+            raise ValueError(f"API {api_str} not found in routing table")
+
+        providers = api_providers()[Api(api_str)]
+        routing_list = self.provider_routing_table[api_str]
+
+        self.api2routes[api_str] = {}
+        for rt_entry in routing_list:
+            rt_key = rt_entry.routing_key
+            provider_id = rt_entry.provider_id
+            impl = await instantiate_provider(
+                providers[provider_id],
+                deps=[],
+                provider_config=GenericProviderConfig(
+                    provider_id=provider_id, config=rt_entry.config
+                ),
+            )
+            cprint(f"impl = {impl}", "red")
+            self.api2routes[api_str][rt_key] = impl
+
+        cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
+
+    async def shutdown(self, api_str: str) -> None:
+        """Shutdown the routing table"""
+        if api_str not in self.api2routes:
+            return
+
+        for impl in self.api2routes[api_str].values():
+            await impl.shutdown()
+
+    def get_provider_impl(self, api: str, routing_key: str) -> Any:
+        """Get the provider impl for a given api and routing key"""
+        return self.api2routes[api][routing_key]
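The registration loop is the core of the commit: `initialize` resolves each `(routing_key, provider_id, config)` entry to a concrete provider via `instantiate_provider` and files it under `api2routes[api][routing_key]`, so `get_provider_impl` is then a plain dict lookup. A self-contained sketch of the same pattern with stub providers (the classes and names here are illustrative, not the real llama-stack types):

import asyncio
from typing import Any, Dict, List, Tuple


class StubProvider:
    """Stand-in for a concrete provider impl."""

    def __init__(self, provider_id: str) -> None:
        self.provider_id = provider_id

    async def shutdown(self) -> None:
        print(f"shutdown {self.provider_id}")


class MiniRoutingTable:
    def __init__(self, table: Dict[str, List[Tuple[str, str]]]) -> None:
        # table: {api: [(routing_key, provider_id), ...]}
        self.table = table
        self.api2routes: Dict[str, Dict[str, Any]] = {}

    async def initialize(self, api: str) -> None:
        if api not in self.table:
            raise ValueError(f"API {api} not found in routing table")
        # Register one concrete impl per routing key.
        self.api2routes[api] = {key: StubProvider(pid) for key, pid in self.table[api]}

    def get_provider_impl(self, api: str, routing_key: str) -> Any:
        return self.api2routes[api][routing_key]

    async def shutdown(self, api: str) -> None:
        for impl in self.api2routes.get(api, {}).values():
            await impl.shutdown()


async def main() -> None:
    rt = MiniRoutingTable({"inference": [("Meta-Llama3.1-8B", "meta-reference")]})
    await rt.initialize("inference")
    print(rt.get_provider_impl("inference", "Meta-Llama3.1-8B").provider_id)
    await rt.shutdown("inference")


asyncio.run(main())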
@@ -21,18 +21,22 @@ provider_routing_table:
         max_seq_len: 4096
         max_batch_size: 1
     - routing_key: Meta-Llama3.1-8B
-      provider_id: remote::ollama
-      config:
-        url: http:ollama-url-1.com
-  memory:
-    - routing_key: keyvalue
-      provider_id: remote::pgvector
-      config:
-        host: localhost
-        port: 5432
-        db: vectordb
-        user: vectoruser
-        password: xxxx
-    - routing_key: vector
       provider_id: meta-reference
-      config: {}
+      config:
+        model: Meta-Llama3.1-8B
+        quantization: null
+        torch_seed: null
+        max_seq_len: 4096
+        max_batch_size: 1
+  # memory:
+  #   - routing_key: keyvalue
+  #     provider_id: remote::pgvector
+  #     config:
+  #       host: localhost
+  #       port: 5432
+  #       db: vectordb
+  #       user: vectoruser
+  #       password: xxxx
+  #   - routing_key: vector
+  #     provider_id: meta-reference
+  #     config: {}
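The run config now routes the single `Meta-Llama3.1-8B` routing key to the `meta-reference` provider inline, with the memory section commented out. A hedged sketch of loading such a block into typed entries; the `RoutingEntry` model and the file path are assumptions for illustration, not the real `ProviderRoutingEntry`:

from typing import Any, Dict, List

import yaml
from pydantic import BaseModel


class RoutingEntry(BaseModel):
    # Illustrative stand-in whose fields mirror the YAML above.
    routing_key: str
    provider_id: str
    config: Dict[str, Any]


with open("run-config.yaml") as f:  # path is an assumption
    doc = yaml.safe_load(f)

table: Dict[str, List[RoutingEntry]] = {
    api: [RoutingEntry(**entry) for entry in entries]
    for api, entries in doc["provider_routing_table"].items()
}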