mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
fix case where memory bank is registered without provider_id
This commit is contained in:
parent
9fcf5d58e0
commit
f0600a30c9
2 changed files with 23 additions and 2 deletions
|
@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
response = await client.list_memory_banks()
|
response = await client.list_memory_banks()
|
||||||
cprint(f"list_memory_banks response={response}", "green")
|
cprint(f"list_memory_banks response={response}", "green")
|
||||||
|
|
||||||
|
# register memory bank for the first time
|
||||||
|
response = await client.register_memory_bank(
|
||||||
|
VectorMemoryBankDef(
|
||||||
|
identifier="test_bank2",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
overlap_size_in_tokens=64,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cprint(f"register_memory_bank response={response}", "blue")
|
||||||
|
|
||||||
|
# list again after registering
|
||||||
|
response = await client.list_memory_banks()
|
||||||
|
cprint(f"list_memory_banks response={response}", "green")
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int, stream: bool = True):
|
def main(host: str, port: int, stream: bool = True):
|
||||||
asyncio.run(run_main(host, port, stream))
|
asyncio.run(run_main(host, port, stream))
|
||||||
|
|
|
@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||||
entries = self.registry.get(obj.identifier, [])
|
entries = self.registry.get(obj.identifier, [])
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if entry.provider_id == obj.provider_id:
|
if entry.provider_id == obj.provider_id or not obj.provider_id:
|
||||||
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
print(
|
||||||
|
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
|
||||||
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
if obj.provider_id not in self.impls_by_provider_id:
|
if obj.provider_id not in self.impls_by_provider_id:
|
||||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue