[api_updates_3] fix CLI for routing_table, bug fixes for memory & safety (#90)

* fix llama stack build

* fix configure

* fix configure for simple case

* configure w/ routing

* move examples config

* fix memory router naming

* issue w/ safety

* fix config w/ safety

* update memory endpoints

* allow providers in api_providers

* configure script works

* all endpoints work w/ simple local build->configure->run

* new example run.yaml

* run openapi generator
This commit is contained in:
Xi Yan 2024-09-23 08:46:33 -07:00 committed by GitHub
parent 8cf634e615
commit ddebf9b6e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 725 additions and 605 deletions

View file

@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
_deps,
) -> Any:
from .routing_tables import (

View file

@ -46,9 +46,9 @@ class MemoryRouter(Memory):
url: Optional[URL] = None,
) -> MemoryBank:
bank_type = config.type
provider = await self.routing_table.get_provider_impl(
bank_type
).create_memory_bank(name, config, url)
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
name, config, url
)
self.bank_id_to_type[bank.bank_id] = bank_type
return bank
@ -162,6 +162,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
print(f"Running shield {shield_type}")
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
messages=messages,

View file

@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
return self.routing_keys
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,