mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
Further generalize Xi's changes (#88)
* Further generalize Xi's changes - introduce a slightly more general notion of an AutoRouted provider - the AutoRouted provider is associated with a RoutingTable provider - e.g. inference -> models - Introduced safety -> shields and memory -> memory_banks correspondences * typo * Basic build and run succeeded
This commit is contained in:
parent
b8914bb56f
commit
c1ab66f1e6
21 changed files with 597 additions and 418 deletions
|
@ -4,25 +4,47 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
async def get_router_impl(
|
||||
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
|
||||
):
|
||||
from .routers import InferenceRouter, MemoryRouter
|
||||
from .routing_table import RoutingTable
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: RoutingTableConfig,
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api2routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
}
|
||||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
# initialize routing table with concrete provider impls
|
||||
routing_table = RoutingTable(provider_routing_table)
|
||||
|
||||
impl = api2routers[api](routing_table)
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,17 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
|
||||
from .routing_table import RoutingTable
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from types import MethodType
|
||||
|
||||
from termcolor import cprint
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
|
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.api = Api.memory.value
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.routing_table.initialize(self.api)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.routing_table.shutdown(self.api)
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
return self.routing_table.get_provider_impl(self.api, bank_type)
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
|
|||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(
|
||||
self.api, bank_type
|
||||
provider = await self.routing_table.get_provider_impl(
|
||||
bank_type
|
||||
).create_memory_bank(name, config, url)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.api = Api.inference.value
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.routing_table.initialize(self.api)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.routing_table.shutdown(self.api)
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = [],
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(
|
||||
self.api, model
|
||||
).chat_completion(
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
|
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(self.api, model).completion(
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
|
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
|
|||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
|
||||
return await self.routing_table.get_provider_impl(model).embeddings(
|
||||
model=model,
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
Api,
|
||||
GenericProviderConfig,
|
||||
ProviderRoutingEntry,
|
||||
)
|
||||
from llama_stack.distribution.distribution import api_providers
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
|
||||
self.provider_routing_table = provider_routing_table
|
||||
# map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
|
||||
self.api2routes = {}
|
||||
|
||||
async def initialize(self, api_str: str) -> None:
|
||||
"""Initialize the routing table with concrete provider impls"""
|
||||
if api_str not in self.provider_routing_table:
|
||||
raise ValueError(f"API {api_str} not found in routing table")
|
||||
|
||||
providers = api_providers()[Api(api_str)]
|
||||
routing_list = self.provider_routing_table[api_str]
|
||||
|
||||
self.api2routes[api_str] = {}
|
||||
for rt_entry in routing_list:
|
||||
rt_key = rt_entry.routing_key
|
||||
provider_id = rt_entry.provider_id
|
||||
impl = await instantiate_provider(
|
||||
providers[provider_id],
|
||||
deps=[],
|
||||
provider_config=GenericProviderConfig(
|
||||
provider_id=provider_id, config=rt_entry.config
|
||||
),
|
||||
)
|
||||
cprint(f"impl = {impl}", "red")
|
||||
self.api2routes[api_str][rt_key] = impl
|
||||
|
||||
cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
|
||||
|
||||
async def shutdown(self, api_str: str) -> None:
|
||||
"""Shutdown the routing table"""
|
||||
if api_str not in self.api2routes:
|
||||
return
|
||||
|
||||
for impl in self.api2routes[api_str].values():
|
||||
await impl.shutdown()
|
||||
|
||||
def get_provider_impl(self, api: str, routing_key: str) -> Any:
|
||||
"""Get the provider impl for a given api and routing key"""
|
||||
return self.api2routes[api][routing_key]
|
118
llama_stack/distribution/routers/routing_tables.py
Normal file
118
llama_stack/distribution/routers/routing_tables.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: RoutingTableConfig,
|
||||
) -> None:
|
||||
self.providers = {k: v for k, v in inner_impls}
|
||||
self.routing_keys = list(self.providers.keys())
|
||||
self.routing_table_config = routing_table_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
async def get_provider_impl(self, routing_key: str) -> Optional[Any]:
|
||||
return self.providers.get(routing_key)
|
||||
|
||||
async def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
async def get_provider_config(
|
||||
self, routing_key: str
|
||||
) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue