Kill get_routing_keys(), rename register_ to validate_

This commit is contained in:
Ashwin Bharambe 2024-09-30 17:08:56 -07:00
parent aab81cd5ad
commit 1702aa5e3f
11 changed files with 20 additions and 48 deletions

View file

@@ -44,7 +44,7 @@ class CommonRoutingTableImpl(RoutingTable):
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue continue
await p.register_routing_keys(keys) await p.validate_routing_keys(keys)
async def shutdown(self) -> None: async def shutdown(self) -> None:
for _, p in self.unique_providers: for _, p in self.unique_providers:

View file

@@ -34,13 +34,10 @@ class _HfAdapter(Inference, RoutableProvider):
self.tokenizer = Tokenizer.get_instance() self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer) self.formatter = ChatFormat(self.tokenizer)
async def register_routing_keys(self, routing_keys: list[str]) -> None: async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider # these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary # perform validation here if necessary
self.routing_keys = routing_keys pass
def get_routing_keys(self) -> list[str]:
return self.routing_keys
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass

View file

@@ -93,12 +93,9 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[chroma] Registering memory bank routing keys: {routing_keys}") print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys pass
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank( async def create_memory_bank(
self, self,

View file

@@ -161,12 +161,9 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}") print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys pass
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank( async def create_memory_bank(
self, self,

View file

@@ -45,16 +45,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys: for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES: if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}") raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:

View file

@@ -35,14 +35,10 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys: for key in routing_keys:
if key not in SAFETY_SHIELD_TYPES: if key not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}") raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None

View file

@@ -53,9 +53,13 @@ class RoutingTable(Protocol):
class RoutableProvider(Protocol): class RoutableProvider(Protocol):
async def register_routing_keys(self, keys: List[str]) -> None: ... """
A provider which sits behind the RoutingTable and can get routed to.
def get_routing_keys(self) -> List[str]: ... All Inference / Safety / Memory providers fall into this bucket.
"""
async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel): class GenericProviderConfig(BaseModel):

View file

@@ -38,15 +38,12 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
self.generator = LlamaModelParallelGenerator(self.config) self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start() self.generator.start()
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
assert ( assert (
len(routing_keys) == 1 len(routing_keys) == 1
), f"Only one routing key is supported {routing_keys}" ), f"Only one routing key is supported {routing_keys}"
assert routing_keys[0] == self.config.model assert routing_keys[0] == self.config.model
def get_routing_keys(self) -> List[str]:
return [self.config.model]
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.generator.stop() self.generator.stop()

View file

@@ -72,12 +72,9 @@ class FaissMemoryImpl(Memory, RoutableProvider):
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}") print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys pass
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def create_memory_bank( async def create_memory_bank(
self, self,

View file

@@ -51,15 +51,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_routing_keys(self, routing_keys: List[str]) -> None: async def validate_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType] available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys: for key in routing_keys:
if key not in available_shields: if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}") raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield( async def run_shield(
self, self,

View file

@@ -16,16 +16,12 @@ class RoutableProviderForModels(RoutableProvider):
def __init__(self, stack_to_provider_models_map: Dict[str, str]): def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map self.stack_to_provider_models_map = stack_to_provider_models_map
async def register_routing_keys(self, routing_keys: List[str]): async def validate_routing_keys(self, routing_keys: List[str]):
for routing_key in routing_keys: for routing_key in routing_keys:
if routing_key not in self.stack_to_provider_models_map: if routing_key not in self.stack_to_provider_models_map:
raise ValueError( raise ValueError(
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}" f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
) )
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
def map_to_provider_model(self, routing_key: str) -> str: def map_to_provider_model(self, routing_key: str) -> str:
model = resolve_model(routing_key) model = resolve_model(routing_key)