Need to await for get_object_from_identifier() now

This commit is contained in:
Ashwin Bharambe 2024-11-04 20:32:47 -08:00
parent 7cf4c905f3
commit 9a57a009ee

View file

@ -188,7 +188,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return await self.get_all_with_type("model")
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
return await self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDefWithProvider) -> None:
await self.register_object(model)
@ -199,7 +199,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return await self.get_all_with_type("shield")
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return self.get_object_by_identifier(shield_type)
return await self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield)
@ -212,7 +212,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def get_memory_bank(
self, identifier: str
) -> Optional[MemoryBankDefWithProvider]:
return self.get_object_by_identifier(identifier)
return await self.get_object_by_identifier(identifier)
async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider
@ -227,7 +227,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def get_dataset(
self, dataset_identifier: str
) -> Optional[DatasetDefWithProvider]:
return self.get_object_by_identifier(dataset_identifier)
return await self.get_object_by_identifier(dataset_identifier)
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
await self.register_object(dataset_def)
@ -240,7 +240,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFnDefWithProvider]:
return self.get_object_by_identifier(name)
return await self.get_object_by_identifier(name)
async def register_scoring_function(
self, function_def: ScoringFnDefWithProvider