[api_updates_3] fix CLI for routing_table, bug fixes for memory & safety (#90)

* fix llama stack build

* fix configure

* fix configure for simple case

* configure w/ routing

* move examples config

* fix memory router naming

* issue w/ safety

* fix config w/ safety

* update memory endpoints

* allow providers in api_providers

* configure script works

* all endpoints w/ build->configure->run simple local works

* new example run.yaml

* run openapi generator
This commit is contained in:
Xi Yan 2024-09-23 08:46:33 -07:00 committed by GitHub
parent 8cf634e615
commit ddebf9b6e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 725 additions and 605 deletions

View file

@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
return self.routing_keys
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,