Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- introduce safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -4,25 +4,47 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
from llama_stack.distribution.datatypes import * # noqa: F403
async def get_router_impl(
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
):
from .routers import InferenceRouter, MemoryRouter
from .routing_table import RoutingTable
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
_deps,
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
ShieldsRoutingTable,
)
api2routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
# initialize routing table with concrete provider impls
routing_table = RoutingTable(provider_routing_table)
impl = api2routers[api](routing_table)
impl = api_to_tables[api.value](inner_impls, routing_table_config)
await impl.initialize()
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
    """Build and initialize the auto-routed implementation for ``api``.

    The router class registered under ``api.value`` is constructed around
    ``routing_table`` and its ``initialize()`` is awaited before it is
    returned. ``_deps`` is accepted but not used in this function.

    Raises:
        ValueError: if ``api.value`` has no registered router class.
    """
    # Imported inside the function, as in the original.
    from .routers import InferenceRouter, MemoryRouter, SafetyRouter

    router_classes = dict(
        memory=MemoryRouter,
        inference=InferenceRouter,
        safety=SafetyRouter,
    )

    router_cls = router_classes.get(api.value)
    if router_cls is None:
        raise ValueError(f"API {api.value} not found in router map")

    router = router_cls(routing_table)
    await router.initialize()
    return router

View file

@ -4,17 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Tuple
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import RoutingTable
from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from types import MethodType
from termcolor import cprint
from llama_stack.apis.safety import * # noqa: F403
class MemoryRouter(Memory):
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.memory.value
self.routing_table = routing_table
self.bank_id_to_type = {}
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
def get_provider_from_bank_id(self, bank_id: str) -> Any:
bank_type = self.bank_id_to_type.get(bank_id)
if not bank_type:
raise ValueError(f"Could not find bank type for {bank_id}")
return self.routing_table.get_provider_impl(self.api, bank_type)
provider = self.routing_table.get_provider_impl(bank_type)
if not provider:
raise ValueError(f"Could not find provider for {bank_type}")
return provider
async def create_memory_bank(
self,
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
bank = await self.routing_table.get_provider_impl(
self.api, bank_type
provider = await self.routing_table.get_provider_impl(
bank_type
).create_memory_bank(name, config, url)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
provider = self.get_provider_from_bank_id(bank_id)
return await provider.get_memory_bank(bank_id)
async def insert_documents(
self,
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
self.api = Api.inference.value
self.routing_table = routing_table
async def initialize(self) -> None:
await self.routing_table.initialize(self.api)
pass
async def shutdown(self) -> None:
await self.routing_table.shutdown(self.api)
pass
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = [],
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(
self.api, model
).chat_completion(
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(self.api, model).completion(
return await self.routing_table.get_provider_impl(model).completion(
model=model,
content=content,
sampling_params=sampling_params,
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
return await self.routing_table.get_provider_impl(model).embeddings(
model=model,
contents=contents,
)
class SafetyRouter(Safety):
    """Safety implementation that forwards calls to a routed provider.

    The provider is looked up in the routing table using the shield type as
    the routing key.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """No-op; this router holds no resources of its own."""
        pass

    async def shutdown(self) -> None:
        """No-op; this router holds no resources of its own."""
        pass

    async def run_shield(
        self,
        shield_type: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        """Run the shield ``shield_type`` on ``messages`` via its provider.

        Args:
            shield_type: routing key used to select the provider; also
                forwarded to the provider unchanged.
            messages: messages forwarded to the provider.
            params: optional extra parameters, forwarded as-is (may be None).

        Raises:
            ValueError: if the routing table has no provider for
                ``shield_type``.
        """
        provider = self.routing_table.get_provider_impl(shield_type)
        # Fail with a clear message instead of an AttributeError on None,
        # mirroring MemoryRouter.get_provider_from_bank_id.
        if not provider:
            raise ValueError(f"Could not find provider for {shield_type}")
        return await provider.run_shield(
            shield_type=shield_type,
            messages=messages,
            params=params,
        )

View file

@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.distribution.datatypes import (
Api,
GenericProviderConfig,
ProviderRoutingEntry,
)
from llama_stack.distribution.distribution import api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from termcolor import cprint
class RoutingTable:
    """Per-API table mapping routing keys to instantiated provider impls."""

    def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
        # Configuration as supplied: {api name: [routing entries]}.
        self.provider_routing_table = provider_routing_table
        # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
        self.api2routes = {}

    async def initialize(self, api_str: str) -> None:
        """Instantiate one provider impl per routing entry configured for ``api_str``.

        Raises:
            ValueError: if ``api_str`` has no entry in the configured table.
        """
        if api_str not in self.provider_routing_table:
            raise ValueError(f"API {api_str} not found in routing table")

        available = api_providers()[Api(api_str)]
        entries = self.provider_routing_table[api_str]

        # Register the (initially empty) route map before instantiating so
        # the table reflects whatever was built so far.
        routes = {}
        self.api2routes[api_str] = routes

        for entry in entries:
            generic_config = GenericProviderConfig(
                provider_id=entry.provider_id, config=entry.config
            )
            impl = await instantiate_provider(
                available[entry.provider_id],
                deps=[],
                provider_config=generic_config,
            )
            cprint(f"impl = {impl}", "red")
            routes[entry.routing_key] = impl

        cprint(f"> Initialized implementations for {api_str} in routing table", "blue")

    async def shutdown(self, api_str: str) -> None:
        """Shut down every impl registered for ``api_str``; no-op if none are."""
        for impl in self.api2routes.get(api_str, {}).values():
            await impl.shutdown()

    def get_provider_impl(self, api: str, routing_key: str) -> Any:
        """Return the impl registered for ``api`` under ``routing_key``.

        Raises KeyError when either the api or the routing key is unknown.
        """
        routes = self.api2routes[api]
        return routes[routing_key]

View file

@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
class CommonRoutingTableImpl(RoutingTable):
    """Routing table backed by already-instantiated provider implementations.

    ``inner_impls`` is a list of ``(routing_key, impl)`` pairs, and
    ``routing_table_config`` carries the configuration entries the impls were
    built from.
    """

    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        routing_table_config: RoutingTableConfig,
    ) -> None:
        # routing_key -> provider impl; a later duplicate key wins.
        self.providers = dict(inner_impls)
        self.routing_keys = list(self.providers.keys())
        self.routing_table_config = routing_table_config

    async def initialize(self) -> None:
        """No-op; the impls are passed in already constructed."""
        pass

    async def shutdown(self) -> None:
        """Shut down every provider impl held by this table."""
        for provider in self.providers.values():
            await provider.shutdown()

    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
        """Return the impl registered under ``routing_key``, or None.

        This is a plain (non-async) method: the routers call it and use the
        result immediately, without awaiting the lookup. As a coroutine
        function it would hand them a coroutine object instead of a provider.
        """
        return self.providers.get(routing_key)

    async def get_routing_keys(self) -> List[str]:
        """Return the routing keys of the registered impls."""
        return self.routing_keys

    async def get_provider_config(
        self, routing_key: str
    ) -> Optional[GenericProviderConfig]:
        """Return the first config entry whose routing key matches, or None."""
        for entry in self.routing_table_config.entries:
            if entry.routing_key == routing_key:
                return entry
        return None
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Models API served from the routing table configuration entries."""

    async def list_models(self) -> List[ModelServingSpec]:
        """Return one serving spec per configured entry, in entry order.

        Each entry's routing key is passed to ``resolve_model`` to obtain the
        ``llama_model`` field.
        """
        return [
            ModelServingSpec(
                llama_model=resolve_model(entry.routing_key),
                provider_config=entry,
            )
            for entry in self.routing_table_config.entries
        ]

    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        """Return the spec for the first entry keyed by ``core_model_id``, or None."""
        candidates = (
            entry
            for entry in self.routing_table_config.entries
            if entry.routing_key == core_model_id
        )
        match = next(candidates, None)
        if match is None:
            return None
        return ModelServingSpec(
            llama_model=resolve_model(core_model_id),
            provider_config=match,
        )
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    """Shields API served from the routing table configuration entries."""

    async def list_shields(self) -> List[ShieldSpec]:
        """Return one shield spec per configured entry, in entry order."""
        return [
            ShieldSpec(
                shield_type=entry.routing_key,
                provider_config=entry,
            )
            for entry in self.routing_table_config.entries
        ]

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        """Return the spec for the first entry keyed by ``shield_type``, or None."""
        candidates = (
            entry
            for entry in self.routing_table_config.entries
            if entry.routing_key == shield_type
        )
        match = next(candidates, None)
        if match is None:
            return None
        return ShieldSpec(
            shield_type=match.routing_key,
            provider_config=match,
        )
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    """Memory banks API served from the routing table configuration entries."""

    async def list_memory_banks(self) -> List[MemoryBankSpec]:
        """Return one memory bank spec per configured entry, in entry order."""
        return [
            MemoryBankSpec(
                bank_type=entry.routing_key,
                provider_config=entry,
            )
            for entry in self.routing_table_config.entries
        ]

    async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
        """Return the spec for the first entry keyed by ``bank_type``, or None."""
        candidates = (
            entry
            for entry in self.routing_table_config.entries
            if entry.routing_key == bank_type
        )
        match = next(candidates, None)
        if match is None:
            return None
        return MemoryBankSpec(
            bank_type=match.routing_key,
            provider_config=match,
        )