mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
skeleton unified routing table, api routers
This commit is contained in:
parent
2dc14cba2c
commit
85d927adde
11 changed files with 210 additions and 231 deletions
|
@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
|
|||
|
||||
import fire
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from .event_logger import EventLogger
|
||||
|
||||
|
@ -104,11 +104,9 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
)
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
ChatCompletionRequest(
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
model="Meta-Llama3.1-8B-Instruct",
|
||||
messages=[message],
|
||||
stream=stream,
|
||||
)
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
|
|
@ -43,6 +43,16 @@ class ProviderSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderRoutingEntry(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouterProviderSpec(ProviderSpec):
|
||||
provider_id: str = "router"
|
||||
|
@ -50,14 +60,20 @@ class RouterProviderSpec(ProviderSpec):
|
|||
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
routing_table: List[ProviderRoutingEntry] = Field(
|
||||
default_factory=list,
|
||||
description="Routing table entries corresponding to the API",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -65,11 +81,6 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
|
@ -204,12 +215,7 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderRoutingEntry(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||
ProviderMapEntry = GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -248,6 +254,21 @@ As examples:
|
|||
|
||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||
)
|
||||
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
|
||||
description="""
|
||||
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for inference API:
|
||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||
provider_id: meta-reference
|
||||
config:
|
||||
model: Meta-Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
27
llama_stack/distribution/routers/__init__.py
Normal file
27
llama_stack/distribution/routers/__init__.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
||||
async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
|
||||
from .routers import InferenceRouter, MemoryRouter
|
||||
from .routing_table import RoutingTable
|
||||
|
||||
api2routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
}
|
||||
|
||||
routing_table = RoutingTable(provider_routing_table)
|
||||
routing_table.print()
|
||||
|
||||
impl = api2routers[api](routing_table)
|
||||
# impl = Router(api, provider_routing_table)
|
||||
await impl.initialize()
|
||||
return impl
|
84
llama_stack/distribution/routers/routers.py
Normal file
84
llama_stack/distribution/routers/routers.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .routing_table import RoutingTable
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
print("MemoryRouter: create_memory_bank")
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
print("MemoryRouter: get_memory_bank")
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
print("MemoryRouter: insert_documents")
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
print("query_documents")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = list,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
print("Inference Router: chat_completion")
|
16
llama_stack/distribution/routers/routing_table.py
Normal file
16
llama_stack/distribution/routers/routing_table.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
def __init__(self, provider_routing_table: Dict[str, Any]):
|
||||
self.provider_routing_table = provider_routing_table
|
||||
|
||||
def print(self):
|
||||
print(f"ROUTING TABLE {self.provider_routing_table}")
|
|
@ -50,7 +50,10 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
|||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from llama_stack.distribution.utils.dynamic import (
|
||||
instantiate_provider,
|
||||
instantiate_router,
|
||||
)
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
|
@ -288,8 +291,8 @@ def snake_to_camel(snake_str):
|
|||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
|
||||
async def resolve_impls(
|
||||
provider_map: Dict[str, ProviderMapEntry],
|
||||
async def resolve_impls_with_routing(
|
||||
stack_run_config: StackRunConfig,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
|
@ -297,33 +300,28 @@ async def resolve_impls(
|
|||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = api_providers()
|
||||
|
||||
specs = {}
|
||||
for api_str, item in provider_map.items():
|
||||
|
||||
for api_str in stack_run_config.apis_to_serve:
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
|
||||
if isinstance(item, GenericProviderConfig):
|
||||
if item.provider_id not in providers:
|
||||
# check for regular providers without routing
|
||||
if api_str in stack_run_config.provider_map:
|
||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||
if provider_map_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
else:
|
||||
assert isinstance(item, list)
|
||||
inner_specs = []
|
||||
for rt_entry in item:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
specs[api] = providers[provider_map_entry.provider_id]
|
||||
|
||||
# check for routing table, we need to pass routing table to the router implementation
|
||||
if api_str in stack_run_config.provider_routing_table:
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||
module=f"llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
routing_table=stack_run_config.provider_routing_table[api_str],
|
||||
)
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
@ -331,9 +329,16 @@ async def resolve_impls(
|
|||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||
if api.value in stack_run_config.provider_map:
|
||||
provider_config = stack_run_config.provider_map[api.value]
|
||||
impl = await instantiate_provider(spec, deps, provider_config)
|
||||
elif api.value in stack_run_config.provider_routing_table:
|
||||
impl = await instantiate_router(
|
||||
spec, api.value, stack_run_config.provider_routing_table
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
@ -345,7 +350,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -16,6 +16,19 @@ def instantiate_class_type(fully_qualified_name):
|
|||
return getattr(module, class_name)
|
||||
|
||||
|
||||
async def instantiate_router(
|
||||
provider_spec: RouterProviderSpec,
|
||||
api: str,
|
||||
provider_routing_table: Dict[str, Any],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
fn = getattr(module, "get_router_impl")
|
||||
impl = await fn(api, provider_routing_table)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
return impl
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
|
@ -35,22 +48,6 @@ async def instantiate_provider(
|
|||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, RouterProviderSpec):
|
||||
method = "get_router_impl"
|
||||
|
||||
assert isinstance(provider_config, list)
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in provider_config:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ conda_env: local
|
|||
apis_to_serve:
|
||||
- inference
|
||||
# - memory
|
||||
# - telemetry
|
||||
- telemetry
|
||||
provider_map:
|
||||
telemetry:
|
||||
provider_id: meta-reference
|
||||
|
@ -36,60 +36,3 @@ provider_routing_table:
|
|||
- routing_key: vector
|
||||
provider_id: meta-reference
|
||||
config: {}
|
||||
|
||||
|
||||
|
||||
# safety:
|
||||
# provider_id: meta-reference
|
||||
# config:
|
||||
# llama_guard_shield:
|
||||
# model: Llama-Guard-3-8B
|
||||
# excluded_categories: []
|
||||
# disable_input_check: false
|
||||
# disable_output_check: false
|
||||
# prompt_guard_shield:
|
||||
# model: Prompt-Guard-86M
|
||||
# telemetry:
|
||||
# provider_id: meta-reference
|
||||
# config: {}
|
||||
# agents:
|
||||
# provider_id: meta-reference
|
||||
# config: {}
|
||||
# memory:
|
||||
# provider_id: meta-reference
|
||||
# config: {}
|
||||
# models:
|
||||
# provider_id: builtin
|
||||
# config:
|
||||
# models_config:
|
||||
# - core_model_id: Meta-Llama3.1-8B-Instruct
|
||||
# provider_id: meta-reference
|
||||
# api: inference
|
||||
# config:
|
||||
# model: Meta-Llama3.1-8B-Instruct
|
||||
# quantization: null
|
||||
# torch_seed: null
|
||||
# max_seq_len: 4096
|
||||
# max_batch_size: 1
|
||||
# - core_model_id: Meta-Llama3.1-8B
|
||||
# provider_id: meta-reference
|
||||
# api: inference
|
||||
# config:
|
||||
# model: Meta-Llama3.1-8B
|
||||
# quantization: null
|
||||
# torch_seed: null
|
||||
# max_seq_len: 4096
|
||||
# max_batch_size: 1
|
||||
# - core_model_id: Llama-Guard-3-8B
|
||||
# provider_id: meta-reference
|
||||
# api: safety
|
||||
# config:
|
||||
# model: Llama-Guard-3-8B
|
||||
# excluded_categories: []
|
||||
# disable_input_check: false
|
||||
# disable_output_check: false
|
||||
# - core_model_id: Prompt-Guard-86M
|
||||
# provider_id: meta-reference
|
||||
# api: safety
|
||||
# config:
|
||||
# model: Prompt-Guard-86M
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
||||
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
|
||||
from .memory import MemoryRouterImpl
|
||||
|
||||
impl = MemoryRouterImpl(inner_impls, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,91 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouterImpl(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
deps: List[Api],
|
||||
) -> None:
|
||||
self.deps = deps
|
||||
|
||||
bank_types = [v.value for v in MemoryBankType]
|
||||
|
||||
self.providers = {}
|
||||
for routing_key, provider_impl in inner_impls:
|
||||
if routing_key not in bank_types:
|
||||
raise ValueError(
|
||||
f"Unknown routing key `{routing_key}` for memory bank type"
|
||||
)
|
||||
self.providers[routing_key] = provider_impl
|
||||
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider(self, bank_type):
|
||||
if bank_type not in self.providers:
|
||||
raise ValueError(f"Memory bank type {bank_type} not supported")
|
||||
|
||||
return self.providers[bank_type]
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
provider = self.get_provider(config.type)
|
||||
bank = await provider.create_memory_bank(name, config, url)
|
||||
self.bank_id_to_type[bank.bank_id] = config.type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.get_provider(bank_type)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.get_provider(bank_type)
|
||||
return await provider.insert_documents(bank_id, documents, ttl_seconds)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.get_provider(bank_type)
|
||||
return await provider.query_documents(bank_id, query, params)
|
Loading…
Add table
Add a link
Reference in a new issue