fix case where memory bank is registered without provider_id

This commit is contained in:
Xi Yan 2024-10-17 16:17:46 -07:00
parent 9fcf5d58e0
commit f0600a30c9
2 changed files with 23 additions and 2 deletions

View file

@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
# register memory bank for the first time
response = await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
)
cprint(f"register_memory_bank response={response}", "blue")
# list again after registering
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))

View file

@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable):
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
if entry.provider_id == obj.provider_id or not obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
)
return
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")