diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 588a93fe2..69be35d02 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool): response = await client.list_memory_banks() cprint(f"list_memory_banks response={response}", "green") + # register memory bank for the first time + response = await client.register_memory_bank( + VectorMemoryBankDef( + identifier="test_bank2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + ) + cprint(f"register_memory_bank response={response}", "blue") + + # list again after registering + response = await client.list_memory_banks() + cprint(f"list_memory_banks response={response}", "green") + def main(host: str, port: int, stream: bool = True): asyncio.run(run_main(host, port, stream)) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 17755f0e4..ede30aea1 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable): async def register_object(self, obj: RoutableObjectWithProvider): entries = self.registry.get(obj.identifier, []) for entry in entries: - if entry.provider_id == obj.provider_id: - print(f"`{obj.identifier}` already registered with `{obj.provider_id}`") + if entry.provider_id == obj.provider_id or not obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{entry.provider_id}`" + ) return + # if provider_id is not specified, we'll pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + if obj.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider `{obj.provider_id}` not found")