nuke updates

This commit is contained in:
Dinesh Yeduguru 2024-11-14 17:05:09 -08:00
parent 690e525a36
commit aa93eeb2b7
15 changed files with 15 additions and 429 deletions

View file

@ -51,18 +51,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
raise ValueError(f"Unknown API {api} for registering object with provider")
async def update_object_with_provider(
obj: RoutableObject, p: Any
) -> Optional[RoutableObject]:
api = get_impl_api(p)
if api == Api.memory:
return await p.update_memory_bank(obj)
elif api == Api.inference:
return await p.update_model(obj)
else:
raise ValueError(f"Update not supported for {api}")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.memory:
@ -176,14 +164,6 @@ class CommonRoutingTableImpl(RoutingTable):
obj, self.impls_by_provider_id[obj.provider_id]
)
async def update_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
registered_obj = await update_object_with_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
return await self.dist_registry.update(registered_obj or obj)
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
@ -256,27 +236,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
registered_model = await self.register_object(model)
return registered_model
async def update_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Model:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
updated_model = Model(
identifier=model_id,
provider_resource_id=provider_model_id
or existing_model.provider_resource_id,
provider_id=provider_id or existing_model.provider_id,
metadata=metadata or existing_model.metadata,
)
registered_model = await self.update_object(updated_model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
@ -357,31 +316,6 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
await self.register_object(memory_bank)
return memory_bank
async def update_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None:
raise ValueError(f"Memory bank {memory_bank_id} not found")
updated_bank = parse_obj_as(
MemoryBank,
{
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id or existing_bank.provider_id,
"provider_resource_id": provider_memory_bank_id
or existing_bank.provider_resource_id,
**params.model_dump(),
},
)
registered_bank = await self.update_object(updated_bank)
return registered_bank
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
existing_bank = await self.get_memory_bank(memory_bank_id)
if existing_bank is None: