From 0763a0b85fa77ee8798635fe450435f67dfc42a0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:06:01 -0800 Subject: [PATCH] Fix for the fix! --- .../distribution/routers/routing_tables.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index c184557c6..17bda0e70 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -178,13 +178,14 @@ class CommonRoutingTableImpl(RoutingTable): await register_object_with_provider(obj, p) await self.dist_registry.register(obj) - async def get_all(self) -> List[RoutableObjectWithProvider]: - return await self.dist_registry.get_all() + async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + return [obj for obj in objs if obj.type == type] class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: return self.get_object_by_identifier(identifier) @@ -195,7 +196,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: - return await self.get_all() + return await self.get_all_with_type("shield") async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: return self.get_object_by_identifier(shield_type) @@ -206,7 +207,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("memory_bank") async def get_memory_bank( self, identifier: str @@ -221,7 +222,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[DatasetDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("dataset") async def get_dataset( self, dataset_identifier: str @@ -234,7 +235,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("scoring_function") async def get_scoring_function( self, name: str