skeleton unified routing table, api routers

This commit is contained in:
Xi Yan 2024-09-21 13:44:33 -07:00
parent 2dc14cba2c
commit 85d927adde
11 changed files with 210 additions and 231 deletions

View file

@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
import fire
import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .event_logger import EventLogger
@ -104,12 +104,10 @@ async def run_main(host: str, port: int, stream: bool):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
ChatCompletionRequest(
model="Meta-Llama3.1-8B-Instruct",
messages=[message],
stream=stream,
)
)
async for log in EventLogger().log(iterator):
log.print()

View file

@ -43,6 +43,16 @@ class ProviderSpec(BaseModel):
)
class GenericProviderConfig(BaseModel):
    """Base config pairing a provider id with its free-form configuration."""

    # Identifier of the provider supplying the implementation.
    provider_id: str
    # Opaque provider-specific configuration; interpreted by the provider.
    config: Dict[str, Any]
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
    """A (routing_key, provider_config) entry of a provider routing table."""

    # Key used to select this provider for a request — e.g. a model name
    # such as "Meta-Llama3.1-8B-Instruct" for the inference API.
    routing_key: str
@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
@ -50,14 +60,20 @@ class RouterProviderSpec(ProviderSpec):
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
routing_table: List[ProviderRoutingEntry] = Field(
default_factory=list,
description="Routing table entries corresponding to the API",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
@ -65,11 +81,6 @@ Fully-qualified name of the module to import. The module is expected to have:
raise AssertionError("Should not be called on RouterProviderSpec")
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
@ -204,12 +215,7 @@ in the runtime configuration to help route to the correct provider.""",
)
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
ProviderMapEntry = GenericProviderConfig
@json_schema_type
@ -248,6 +254,21 @@ As examples:
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
)
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
description="""
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
E.g. The following is a ProviderRoutingEntry for inference API:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
"""
)
@json_schema_type

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
    """Build and initialize the router implementation for *api*.

    Args:
        api: API name to route, e.g. "memory" or "inference".
        provider_routing_table: API -> routing entries mapping taken from
            the run config.

    Returns:
        An initialized router object implementing the API's protocol.

    Raises:
        KeyError: if *api* has no registered router class.
    """
    # Local imports — presumably to avoid import cycles; TODO confirm.
    from .routers import InferenceRouter, MemoryRouter
    from .routing_table import RoutingTable

    api2routers = {
        "memory": MemoryRouter,
        "inference": InferenceRouter,
    }

    routing_table = RoutingTable(provider_routing_table)
    routing_table.print()

    impl = api2routers[api](routing_table)
    await impl.initialize()
    return impl

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
class MemoryRouter(Memory):
    """Routes to a provider based on the memory bank type."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        # Table mapping routing keys (bank types) to provider configs.
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    # NOTE(review): the methods below are skeleton stubs — each prints a
    # trace and returns None despite its return annotation; actual dispatch
    # via self.routing_table is still TODO.
    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        print("MemoryRouter: create_memory_bank")

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        print("MemoryRouter: get_memory_bank")

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        print("MemoryRouter: insert_documents")

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        print("query_documents")
class InferenceRouter(Inference):
    """Routes inference requests to a provider chosen via the routing table."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        # Table mapping routing keys (model names) to provider configs.
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        # BUG FIX: default was the builtin `list` *type* (not a list value);
        # use None — a mutable default would be shared across calls anyway.
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
        # NOTE(review): skeleton stub — prints a trace and returns None;
        # actual dispatch via self.routing_table is still TODO.
        print("Inference Router: chat_completion")

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
class RoutingTable:
    """In-memory view of the provider routing table from the run config."""

    def __init__(self, provider_routing_table: Dict[str, Any]) -> None:
        # Mapping of API name -> list of routing entries, kept exactly as given.
        self.provider_routing_table = provider_routing_table

    def print(self) -> None:
        """Dump the raw table to stdout; debugging aid only."""
        print(f"ROUTING TABLE {self.provider_routing_table}")

View file

@ -50,7 +50,10 @@ from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.utils.dynamic import (
instantiate_provider,
instantiate_router,
)
def is_async_iterator_type(typ):
@ -288,8 +291,8 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
async def resolve_impls_with_routing(
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
"""
Does two things:
@ -297,33 +300,28 @@ async def resolve_impls(
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
for api_str in stack_run_config.apis_to_serve:
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
# check for regular providers without routing
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
if provider_map_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = providers[provider_map_entry.provider_id]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
module=f"llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
routing_table=stack_run_config.provider_routing_table[api_str],
)
sorted_specs = topological_sort(specs.values())
@ -331,9 +329,16 @@ async def resolve_impls(
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
raise ValueError(f"Cannot find provider_config for Api {api.value}")
impls[api] = impl
return impls, specs
@ -345,7 +350,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map))
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@ -16,6 +16,19 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name)
async def instantiate_router(
    provider_spec: RouterProviderSpec,
    api: str,
    provider_routing_table: Dict[str, Any],
):
    """Import *provider_spec.module* and obtain its router impl for *api*."""
    router_module = importlib.import_module(provider_spec.module)
    get_impl = router_module.get_router_impl
    impl = await get_impl(api, provider_routing_table)
    # Tag the impl so callers can recover the spec it was built from.
    impl.__provider_spec__ = provider_spec
    return impl
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
@ -35,22 +48,6 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, RouterProviderSpec):
method = "get_router_impl"
assert isinstance(provider_config, list)
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in provider_config:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [inner_impls, deps]
else:
method = "get_provider_impl"

View file

@ -5,7 +5,7 @@ conda_env: local
apis_to_serve:
- inference
# - memory
# - telemetry
- telemetry
provider_map:
telemetry:
provider_id: meta-reference
@ -36,60 +36,3 @@ provider_routing_table:
- routing_key: vector
provider_id: meta-reference
config: {}
# safety:
# provider_id: meta-reference
# config:
# llama_guard_shield:
# model: Llama-Guard-3-8B
# excluded_categories: []
# disable_input_check: false
# disable_output_check: false
# prompt_guard_shield:
# model: Prompt-Guard-86M
# telemetry:
# provider_id: meta-reference
# config: {}
# agents:
# provider_id: meta-reference
# config: {}
# memory:
# provider_id: meta-reference
# config: {}
# models:
# provider_id: builtin
# config:
# models_config:
# - core_model_id: Meta-Llama3.1-8B-Instruct
# provider_id: meta-reference
# api: inference
# config:
# model: Meta-Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# - core_model_id: Meta-Llama3.1-8B
# provider_id: meta-reference
# api: inference
# config:
# model: Meta-Llama3.1-8B
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# - core_model_id: Llama-Guard-3-8B
# provider_id: meta-reference
# api: safety
# config:
# model: Llama-Guard-3-8B
# excluded_categories: []
# disable_input_check: false
# disable_output_check: false
# - core_model_id: Prompt-Guard-86M
# provider_id: meta-reference
# api: safety
# config:
# model: Prompt-Guard-86M

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
    """Construct and initialize the memory router over pre-built inner impls.

    Args:
        inner_impls: (routing_key, provider_impl) pairs to route between.
        deps: API dependencies forwarded to the router.

    Returns:
        The initialized MemoryRouterImpl.
    """
    from .memory import MemoryRouterImpl

    impl = MemoryRouterImpl(inner_impls, deps)
    await impl.initialize()
    return impl

View file

@ -1,91 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
from llama_stack.apis.memory import * # noqa: F403
class MemoryRouterImpl(Memory):
    """Routes to a provider based on the memory bank type."""

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        deps: List[Api],
    ) -> None:
        self.deps = deps

        bank_types = [v.value for v in MemoryBankType]

        # Each routing_key must be a known MemoryBankType value; the matching
        # provider impl then handles every bank of that type.
        self.providers = {}
        for routing_key, provider_impl in inner_impls:
            if routing_key not in bank_types:
                raise ValueError(
                    f"Unknown routing key `{routing_key}` for memory bank type"
                )
            self.providers[routing_key] = provider_impl

        # bank_id -> bank type, populated as banks are created here; lookups
        # for banks created elsewhere will fail (see methods below).
        self.bank_id_to_type = {}

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        for p in self.providers.values():
            await p.shutdown()

    def get_provider(self, bank_type):
        # Resolve the provider for a bank type, or fail loudly.
        if bank_type not in self.providers:
            raise ValueError(f"Memory bank type {bank_type} not supported")

        return self.providers[bank_type]

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        provider = self.get_provider(config.type)
        bank = await provider.create_memory_bank(name, config, url)
        # Remember the bank's type so later per-bank calls can be routed.
        self.bank_id_to_type[bank.bank_id] = config.type
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        # NOTE(review): raises for unknown ids instead of returning None,
        # despite the Optional return annotation — confirm intended.
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        provider = self.get_provider(bank_type)
        return await provider.get_memory_bank(bank_id)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        provider = self.get_provider(bank_type)
        return await provider.insert_documents(bank_id, documents, ttl_seconds)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        bank_type = self.bank_id_to_type.get(bank_id)
        if not bank_type:
            raise ValueError(f"Could not find bank type for {bank_id}")

        provider = self.get_provider(bank_type)
        return await provider.query_documents(bank_id, query, params)