From 9a57a009eeab69924ac2e0861f99052d327d99ba Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:32:47 -0800 Subject: [PATCH] Need to await for get_object_from_identifier() now --- llama_stack/distribution/routers/routing_tables.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 17bda0e70..1efd02c89 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -188,7 +188,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: - return self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier(identifier) async def register_model(self, model: ModelDefWithProvider) -> None: await self.register_object(model) @@ -199,7 +199,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return await self.get_all_with_type("shield") async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: - return self.get_object_by_identifier(shield_type) + return await self.get_object_by_identifier(shield_type) async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) @@ -212,7 +212,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def get_memory_bank( self, identifier: str ) -> Optional[MemoryBankDefWithProvider]: - return self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier(identifier) async def register_memory_bank( self, memory_bank: MemoryBankDefWithProvider @@ -227,7 +227,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def get_dataset( self, dataset_identifier: str ) -> Optional[DatasetDefWithProvider]: - return self.get_object_by_identifier(dataset_identifier) + return await self.get_object_by_identifier(dataset_identifier) async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: await self.register_object(dataset_def) @@ -240,7 +240,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def get_scoring_function( self, name: str ) -> Optional[ScoringFnDefWithProvider]: - return self.get_object_by_identifier(name) + return await self.get_object_by_identifier(name) async def register_scoring_function( self, function_def: ScoringFnDefWithProvider