mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
skeleton unified routing table, api routers
This commit is contained in:
parent
2dc14cba2c
commit
85d927adde
11 changed files with 210 additions and 231 deletions
|
@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from .event_logger import EventLogger
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
|
@ -104,11 +104,9 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
)
|
)
|
||||||
cprint(f"User>{message.content}", "green")
|
cprint(f"User>{message.content}", "green")
|
||||||
iterator = client.chat_completion(
|
iterator = client.chat_completion(
|
||||||
ChatCompletionRequest(
|
model="Meta-Llama3.1-8B-Instruct",
|
||||||
model="Meta-Llama3.1-8B-Instruct",
|
messages=[message],
|
||||||
messages=[message],
|
stream=stream,
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
async for log in EventLogger().log(iterator):
|
async for log in EventLogger().log(iterator):
|
||||||
log.print()
|
log.print()
|
||||||
|
|
|
@ -43,6 +43,16 @@ class ProviderSpec(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Provider selection plus that provider's raw settings for one API.
class GenericProviderConfig(BaseModel):
    # Identifier of the provider to use (the sample run config uses
    # `meta-reference`).
    provider_id: str
    # Provider-specific settings; expanded as keyword arguments into the
    # provider's config class when the provider is instantiated.
    config: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
# A GenericProviderConfig tagged with a routing key. The run config's
# `provider_routing_table` maps each API name to a list of these entries.
@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
    # Key intended to select this provider for a request, e.g. a model name
    # for the inference API or `vector` for the memory API.
    routing_key: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RouterProviderSpec(ProviderSpec):
|
class RouterProviderSpec(ProviderSpec):
|
||||||
provider_id: str = "router"
|
provider_id: str = "router"
|
||||||
|
@ -50,14 +60,20 @@ class RouterProviderSpec(ProviderSpec):
|
||||||
|
|
||||||
docker_image: Optional[str] = None
|
docker_image: Optional[str] = None
|
||||||
|
|
||||||
inner_specs: List[ProviderSpec]
|
routing_table: List[ProviderRoutingEntry] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Routing table entries corresponding to the API",
|
||||||
|
)
|
||||||
module: str = Field(
|
module: str = Field(
|
||||||
...,
|
...,
|
||||||
description="""
|
description="""
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||||
""",
|
""",
|
||||||
|
)
|
||||||
|
provider_data_validator: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -65,11 +81,6 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
raise AssertionError("Should not be called on RouterProviderSpec")
|
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||||
|
|
||||||
|
|
||||||
class GenericProviderConfig(BaseModel):
|
|
||||||
provider_id: str
|
|
||||||
config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_id: str = Field(
|
adapter_id: str = Field(
|
||||||
|
@ -204,12 +215,7 @@ in the runtime configuration to help route to the correct provider.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
ProviderMapEntry = GenericProviderConfig
|
||||||
class ProviderRoutingEntry(GenericProviderConfig):
|
|
||||||
routing_key: str
|
|
||||||
|
|
||||||
|
|
||||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -248,6 +254,21 @@ As examples:
|
||||||
|
|
||||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||||
)
|
)
|
||||||
|
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
|
||||||
|
description="""
|
||||||
|
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
|
||||||
|
|
||||||
|
E.g. The following is a ProviderRoutingEntry for inference API:
|
||||||
|
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
|
provider_id: meta-reference
|
||||||
|
config:
|
||||||
|
model: Meta-Llama3.1-8B-Instruct
|
||||||
|
quantization: null
|
||||||
|
torch_seed: null
|
||||||
|
max_seq_len: 4096
|
||||||
|
max_batch_size: 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
27
llama_stack/distribution/routers/__init__.py
Normal file
27
llama_stack/distribution/routers/__init__.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
async def get_router_impl(api: str, provider_routing_table: Dict[str, Any]):
    """Build and initialize the router implementation for ``api``.

    Args:
        api: Name of the API to route. Only ``"memory"`` and ``"inference"``
            have a registered router.
        provider_routing_table: Routing configuration, handed unchanged to
            ``RoutingTable``.

    Returns:
        The router instance, after its ``initialize()`` coroutine has run.

    Raises:
        KeyError: If no router is registered for ``api``.
    """
    from .routers import InferenceRouter, MemoryRouter
    from .routing_table import RoutingTable

    api2routers = {
        "memory": MemoryRouter,
        "inference": InferenceRouter,
    }

    routing_table = RoutingTable(provider_routing_table)
    routing_table.print()

    try:
        router_cls = api2routers[api]
    except KeyError:
        # Same exception type as the plain dict lookup, but say which API was
        # requested and which ones are available.
        supported = ", ".join(sorted(api2routers))
        raise KeyError(
            f"No router registered for API `{api}` (supported: {supported})"
        ) from None

    impl = router_cls(routing_table)
    await impl.initialize()
    return impl
|
84
llama_stack/distribution/routers/routers.py
Normal file
84
llama_stack/distribution/routers/routers.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
|
from .routing_table import RoutingTable
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRouter(Memory):
    """Routes to a provider based on the memory bank type.

    Skeleton implementation: each API method below only prints a trace line
    and implicitly returns ``None``; the routing table is stored but not yet
    consulted.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        # Stub: no bank is created or returned yet.
        print("MemoryRouter: create_memory_bank")

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        # Stub: always yields None.
        print("MemoryRouter: get_memory_bank")

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        # Stub: documents are not forwarded to any provider.
        print("MemoryRouter: insert_documents")

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        # Stub: no query is executed and no response object is returned.
        print("query_documents")
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceRouter(Inference):
    """Skeleton router for the inference API.

    Stores the routing table it is constructed with. ``chat_completion``
    currently only prints a trace line and implicitly returns ``None``.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        # NOTE(review): this default instance is built once, at definition
        # time, and shared by every call that omits the argument. Left as-is;
        # confirm it is meant to mirror the Inference API signature.
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        # The default used to be the `list` type object itself, which is
        # neither None nor a list of tools; None is the "no tools" value.
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
        # Stub: the request is not routed to any provider yet.
        print("Inference Router: chat_completion")
|
16
llama_stack/distribution/routers/routing_table.py
Normal file
16
llama_stack/distribution/routers/routing_table.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class RoutingTable:
    """Thin holder for the provider routing configuration.

    The mapping is stored as given (no copy, no validation) and can be
    dumped to stdout with ``print()``.
    """

    def __init__(self, provider_routing_table: Dict[str, Any]):
        """Keep a reference to ``provider_routing_table``."""
        self.provider_routing_table = provider_routing_table

    def print(self):
        """Write the stored routing table to stdout on a single line."""
        # Inside the method body `print` resolves to the builtin, not to this
        # method, because class scope is not searched from function bodies.
        print(f"ROUTING TABLE {self.provider_routing_table}")
|
|
@ -50,7 +50,10 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
from llama_stack.distribution.utils.dynamic import (
|
||||||
|
instantiate_provider,
|
||||||
|
instantiate_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_async_iterator_type(typ):
|
def is_async_iterator_type(typ):
|
||||||
|
@ -288,8 +291,8 @@ def snake_to_camel(snake_str):
|
||||||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls(
|
async def resolve_impls_with_routing(
|
||||||
provider_map: Dict[str, ProviderMapEntry],
|
stack_run_config: StackRunConfig,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Does two things:
|
||||||
|
@ -297,33 +300,28 @@ async def resolve_impls(
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_providers = api_providers()
|
all_providers = api_providers()
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for api_str, item in provider_map.items():
|
|
||||||
|
for api_str in stack_run_config.apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
providers = all_providers[api]
|
providers = all_providers[api]
|
||||||
|
|
||||||
if isinstance(item, GenericProviderConfig):
|
# check for regular providers without routing
|
||||||
if item.provider_id not in providers:
|
if api_str in stack_run_config.provider_map:
|
||||||
|
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||||
|
if provider_map_entry.provider_id not in providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
specs[api] = providers[item.provider_id]
|
specs[api] = providers[provider_map_entry.provider_id]
|
||||||
else:
|
|
||||||
assert isinstance(item, list)
|
|
||||||
inner_specs = []
|
|
||||||
for rt_entry in item:
|
|
||||||
if rt_entry.provider_id not in providers:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
|
||||||
)
|
|
||||||
inner_specs.append(providers[rt_entry.provider_id])
|
|
||||||
|
|
||||||
|
# check for routing table, we need to pass routing table to the router implementation
|
||||||
|
if api_str in stack_run_config.provider_routing_table:
|
||||||
specs[api] = RouterProviderSpec(
|
specs[api] = RouterProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
module=f"llama_stack.distribution.routers",
|
||||||
api_dependencies=[],
|
api_dependencies=[],
|
||||||
inner_specs=inner_specs,
|
routing_table=stack_run_config.provider_routing_table[api_str],
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_specs = topological_sort(specs.values())
|
sorted_specs = topological_sort(specs.values())
|
||||||
|
@ -331,9 +329,16 @@ async def resolve_impls(
|
||||||
impls = {}
|
impls = {}
|
||||||
for spec in sorted_specs:
|
for spec in sorted_specs:
|
||||||
api = spec.api
|
api = spec.api
|
||||||
|
|
||||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
if api.value in stack_run_config.provider_map:
|
||||||
|
provider_config = stack_run_config.provider_map[api.value]
|
||||||
|
impl = await instantiate_provider(spec, deps, provider_config)
|
||||||
|
elif api.value in stack_run_config.provider_routing_table:
|
||||||
|
impl = await instantiate_router(
|
||||||
|
spec, api.value, stack_run_config.provider_routing_table
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
return impls, specs
|
return impls, specs
|
||||||
|
@ -345,7 +350,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||||
|
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,19 @@ def instantiate_class_type(fully_qualified_name):
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def instantiate_router(
    provider_spec: "RouterProviderSpec",
    api: str,
    provider_routing_table: Dict[str, Any],
):
    """Import the router module named by ``provider_spec`` and build a router.

    Args:
        provider_spec: Spec whose ``module`` attribute is the fully-qualified
            name of a module exposing an async
            ``get_router_impl(api, provider_routing_table)``.
        api: Name of the API the router serves; forwarded to the factory.
        provider_routing_table: Routing configuration; forwarded unchanged.

    Returns:
        The object returned by the module's ``get_router_impl``, with
        ``__provider_spec__`` set to ``provider_spec``.

    Raises:
        ImportError: If ``provider_spec.module`` cannot be imported.
        AttributeError: If the module has no ``get_router_impl``.
    """
    module = importlib.import_module(provider_spec.module)

    impl = await module.get_router_impl(api, provider_routing_table)
    # Tag the implementation with the spec it was created from.
    impl.__provider_spec__ = provider_spec
    return impl
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
async def instantiate_provider(
|
async def instantiate_provider(
|
||||||
provider_spec: ProviderSpec,
|
provider_spec: ProviderSpec,
|
||||||
|
@ -35,22 +48,6 @@ async def instantiate_provider(
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**provider_config.config)
|
config = config_type(**provider_config.config)
|
||||||
args = [config, deps]
|
args = [config, deps]
|
||||||
elif isinstance(provider_spec, RouterProviderSpec):
|
|
||||||
method = "get_router_impl"
|
|
||||||
|
|
||||||
assert isinstance(provider_config, list)
|
|
||||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
|
||||||
inner_impls = []
|
|
||||||
for routing_entry in provider_config:
|
|
||||||
impl = await instantiate_provider(
|
|
||||||
inner_specs[routing_entry.provider_id],
|
|
||||||
deps,
|
|
||||||
routing_entry,
|
|
||||||
)
|
|
||||||
inner_impls.append((routing_entry.routing_key, impl))
|
|
||||||
|
|
||||||
config = None
|
|
||||||
args = [inner_impls, deps]
|
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ conda_env: local
|
||||||
apis_to_serve:
|
apis_to_serve:
|
||||||
- inference
|
- inference
|
||||||
# - memory
|
# - memory
|
||||||
# - telemetry
|
- telemetry
|
||||||
provider_map:
|
provider_map:
|
||||||
telemetry:
|
telemetry:
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
|
@ -36,60 +36,3 @@ provider_routing_table:
|
||||||
- routing_key: vector
|
- routing_key: vector
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# safety:
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# config:
|
|
||||||
# llama_guard_shield:
|
|
||||||
# model: Llama-Guard-3-8B
|
|
||||||
# excluded_categories: []
|
|
||||||
# disable_input_check: false
|
|
||||||
# disable_output_check: false
|
|
||||||
# prompt_guard_shield:
|
|
||||||
# model: Prompt-Guard-86M
|
|
||||||
# telemetry:
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# config: {}
|
|
||||||
# agents:
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# config: {}
|
|
||||||
# memory:
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# config: {}
|
|
||||||
# models:
|
|
||||||
# provider_id: builtin
|
|
||||||
# config:
|
|
||||||
# models_config:
|
|
||||||
# - core_model_id: Meta-Llama3.1-8B-Instruct
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# api: inference
|
|
||||||
# config:
|
|
||||||
# model: Meta-Llama3.1-8B-Instruct
|
|
||||||
# quantization: null
|
|
||||||
# torch_seed: null
|
|
||||||
# max_seq_len: 4096
|
|
||||||
# max_batch_size: 1
|
|
||||||
# - core_model_id: Meta-Llama3.1-8B
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# api: inference
|
|
||||||
# config:
|
|
||||||
# model: Meta-Llama3.1-8B
|
|
||||||
# quantization: null
|
|
||||||
# torch_seed: null
|
|
||||||
# max_seq_len: 4096
|
|
||||||
# max_batch_size: 1
|
|
||||||
# - core_model_id: Llama-Guard-3-8B
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# api: safety
|
|
||||||
# config:
|
|
||||||
# model: Llama-Guard-3-8B
|
|
||||||
# excluded_categories: []
|
|
||||||
# disable_input_check: false
|
|
||||||
# disable_output_check: false
|
|
||||||
# - core_model_id: Prompt-Guard-86M
|
|
||||||
# provider_id: meta-reference
|
|
||||||
# api: safety
|
|
||||||
# config:
|
|
||||||
# model: Prompt-Guard-86M
|
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
|
@ -1,17 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any, List, Tuple
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
|
||||||
|
|
||||||
|
|
||||||
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
|
|
||||||
from .memory import MemoryRouterImpl
|
|
||||||
|
|
||||||
impl = MemoryRouterImpl(inner_impls, deps)
|
|
||||||
await impl.initialize()
|
|
||||||
return impl
|
|
|
@ -1,91 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryRouterImpl(Memory):
|
|
||||||
"""Routes to an provider based on the memory bank type"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inner_impls: List[Tuple[str, Any]],
|
|
||||||
deps: List[Api],
|
|
||||||
) -> None:
|
|
||||||
self.deps = deps
|
|
||||||
|
|
||||||
bank_types = [v.value for v in MemoryBankType]
|
|
||||||
|
|
||||||
self.providers = {}
|
|
||||||
for routing_key, provider_impl in inner_impls:
|
|
||||||
if routing_key not in bank_types:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown routing key `{routing_key}` for memory bank type"
|
|
||||||
)
|
|
||||||
self.providers[routing_key] = provider_impl
|
|
||||||
|
|
||||||
self.bank_id_to_type = {}
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
for p in self.providers.values():
|
|
||||||
await p.shutdown()
|
|
||||||
|
|
||||||
def get_provider(self, bank_type):
|
|
||||||
if bank_type not in self.providers:
|
|
||||||
raise ValueError(f"Memory bank type {bank_type} not supported")
|
|
||||||
|
|
||||||
return self.providers[bank_type]
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
config: MemoryBankConfig,
|
|
||||||
url: Optional[URL] = None,
|
|
||||||
) -> MemoryBank:
|
|
||||||
provider = self.get_provider(config.type)
|
|
||||||
bank = await provider.create_memory_bank(name, config, url)
|
|
||||||
self.bank_id_to_type[bank.bank_id] = config.type
|
|
||||||
return bank
|
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.get_memory_bank(bank_id)
|
|
||||||
|
|
||||||
async def insert_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
documents: List[MemoryBankDocument],
|
|
||||||
ttl_seconds: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.insert_documents(bank_id, documents, ttl_seconds)
|
|
||||||
|
|
||||||
async def query_documents(
|
|
||||||
self,
|
|
||||||
bank_id: str,
|
|
||||||
query: InterleavedTextMedia,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> QueryDocumentsResponse:
|
|
||||||
bank_type = self.bank_id_to_type.get(bank_id)
|
|
||||||
if not bank_type:
|
|
||||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
|
||||||
|
|
||||||
provider = self.get_provider(bank_type)
|
|
||||||
return await provider.query_documents(bank_id, query, params)
|
|
Loading…
Add table
Add a link
Reference in a new issue