mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 13:02:38 +00:00
dataset datasetio
This commit is contained in:
parent
e8de70fdbe
commit
f8d9e4f60f
8 changed files with 249 additions and 10 deletions
|
|
@ -28,6 +28,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
|||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
|
|
@ -81,6 +85,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
add_objects(memory_banks)
|
||||
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
datasets = await p.list_datasets()
|
||||
|
||||
# do in-memory updates due to pesky Annotated unions
|
||||
for d in datasets:
|
||||
d.provider_id = pid
|
||||
|
||||
add_objects(datasets)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
|
@ -138,6 +152,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
await register_object_with_provider(obj, p)
|
||||
|
||||
if obj.identifier not in self.registry:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue