mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Fix for the fix!
This commit is contained in:
parent
fb2678b134
commit
0763a0b85f
1 changed files with 8 additions and 7 deletions
|
@ -178,13 +178,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
await register_object_with_provider(obj, p)
|
await register_object_with_provider(obj, p)
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||||
return await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
|
return [obj for obj in objs if obj.type == type]
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||||
return await self.get_all()
|
return await self.get_all_with_type("model")
|
||||||
|
|
||||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||||
return self.get_object_by_identifier(identifier)
|
return self.get_object_by_identifier(identifier)
|
||||||
|
@ -195,7 +196,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
return await self.get_all()
|
return await self.get_all_with_type("shield")
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||||
return self.get_object_by_identifier(shield_type)
|
return self.get_object_by_identifier(shield_type)
|
||||||
|
@ -206,7 +207,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
||||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||||
return await self.get_all()
|
return await self.get_all_with_type("memory_bank")
|
||||||
|
|
||||||
async def get_memory_bank(
|
async def get_memory_bank(
|
||||||
self, identifier: str
|
self, identifier: str
|
||||||
|
@ -221,7 +222,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
||||||
return await self.get_all()
|
return await self.get_all_with_type("dataset")
|
||||||
|
|
||||||
async def get_dataset(
|
async def get_dataset(
|
||||||
self, dataset_identifier: str
|
self, dataset_identifier: str
|
||||||
|
@ -234,7 +235,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||||
return await self.get_all()
|
return await self.get_all_with_type("scoring_function")
|
||||||
|
|
||||||
async def get_scoring_function(
|
async def get_scoring_function(
|
||||||
self, name: str
|
self, name: str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue