Fix shield_type and routing table breakage

This commit is contained in:
Ashwin Bharambe 2024-11-04 19:40:04 -08:00
parent 657de08f04
commit fb2678b134
6 changed files with 30 additions and 35 deletions

View file

@ -178,13 +178,13 @@ class CommonRoutingTableImpl(RoutingTable):
await register_object_with_provider(obj, p)
await self.dist_registry.register(obj)
async def get_all(self) -> List[RoutableObjectWithProvider]:
return await self.dist_registry.get_all()
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
return await self.get_all()
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
return self.get_object_by_identifier(identifier)
@ -195,10 +195,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
return await self.get_all()
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return self.get_object_by_identifier(shield_type)
@ -209,10 +206,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
return await self.get_all()
async def get_memory_bank(
self, identifier: str
@ -227,10 +221,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> List[DatasetDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
return await self.get_all()
async def get_dataset(
self, dataset_identifier: str
@ -243,10 +234,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
objects = []
for objs in self.registry.values():
objects.extend(objs)
return objects
return await self.get_all()
async def get_scoring_function(
self, name: str