mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
[api_updates_3] fix CLI for routing_table, bug fixes for memory & safety (#90)
* fix llama stack build * fix configure * fix configure for simple case * configure w/ routing * move examples config * fix memory router naming * issue w/ safety * fix config w/ safety * update memory endpoints * allow providers in api_providers * configure script works * all endpoints w/ build->configure->run simple local works * new example run.yaml * run openapi generator
This commit is contained in:
parent
8cf634e615
commit
ddebf9b6e7
18 changed files with 725 additions and 605 deletions
|
@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: RoutingTableConfig,
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
) -> None:
|
||||
self.providers = {k: v for k, v in inner_impls}
|
||||
self.routing_keys = list(self.providers.keys())
|
||||
|
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
return self.routing_keys
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
|
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
|
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return specs
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
|
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
|
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
return specs
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
|
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
|
||||
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
|
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
return specs
|
||||
|
||||
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue