issue w/ safety

This commit is contained in:
Xi Yan 2024-09-22 23:15:34 -07:00
parent e0ad4fb99c
commit 4586692dee
8 changed files with 50 additions and 66 deletions

View file

@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
return self.routing_keys
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,