From e0ad4fb99c3cc9e4847632ad38e1c858f53cc61c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 22:30:53 -0700 Subject: [PATCH] fix memory router naming --- llama_stack/distribution/configure.py | 12 ++-- llama_stack/distribution/routers/routers.py | 6 +- tests/examples/router-local-run.yaml | 75 ++++++++++++--------- 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 4c372056b..361c24416 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -81,16 +81,18 @@ def configure_api_providers( except Exception: existing = None cfg = prompt_for_config(config_type, existing) - config.api_providers[api_str] = GenericProviderConfig( - provider_id=p, - config=cfg.dict(), - ) if api_str in router_api2builtin_api: - # a routing api, we need to assign it a routing_key and put it in the routing_table + # a routing api, we need to infer and assign it a routing_key and put it in the routing_table routing_key = prompt( "> Enter routing key for the {} provider: ".format(api_str), ) + config.routing_table[] + else: + config.api_providers[api_str] = GenericProviderConfig( + provider_id=p, + config=cfg.dict(), + ) print("") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 6d296d20e..c9a536aa0 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -46,9 +46,9 @@ class MemoryRouter(Memory): url: Optional[URL] = None, ) -> MemoryBank: bank_type = config.type - provider = await self.routing_table.get_provider_impl( - bank_type - ).create_memory_bank(name, config, url) + bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank( + name, config, url + ) self.bank_id_to_type[bank.bank_id] = bank_type return bank diff --git a/tests/examples/router-local-run.yaml b/tests/examples/router-local-run.yaml index 9dec673ea..807dcafec 100644 --- a/tests/examples/router-local-run.yaml +++ b/tests/examples/router-local-run.yaml @@ -13,38 +13,47 @@ api_providers: telemetry: provider_id: meta-reference config: {} - safety: - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M agents: provider_id: meta-reference - config: {} -provider_routing_table: - inference: - - routing_key: Meta-Llama3.1-8B-Instruct - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - - routing_key: Meta-Llama3.1-8B - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - memory: - - routing_key: vector - provider_id: meta-reference - config: {} + config: + persistence_store: + namespace: null + type: sqlite + db_path: /home/xiyan/.llama/runtime/kvstore.db +routing_tables: + models: + entries: + - routing_key: Meta-Llama3.1-8B-Instruct + provider_id: meta-reference + config: + model: Meta-Llama3.1-8B-Instruct + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + - routing_key: Meta-Llama3.1-8B + provider_id: meta-reference + config: + model: Meta-Llama3.1-8B + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + memory_banks: + entries: + - routing_key: vector + provider_id: meta-reference + config: {} + shields: + entries: + - routing_key: llama_guard_shield + provider_id: meta-reference + config: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + - routing_key: prompt_guard_shield + provider_id: meta-reference + config: + model: Prompt-Guard-86M