Kill get_routing_keys(), rename register_ to validate_

This commit is contained in:
Ashwin Bharambe 2024-09-30 17:08:56 -07:00
parent aab81cd5ad
commit 1702aa5e3f
11 changed files with 20 additions and 48 deletions

View file

@ -44,7 +44,7 @@ class CommonRoutingTableImpl(RoutingTable):
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.register_routing_keys(keys)
await p.validate_routing_keys(keys)
async def shutdown(self) -> None:
for _, p in self.unique_providers:

View file

@ -34,13 +34,10 @@ class _HfAdapter(Inference, RoutableProvider):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def register_routing_keys(self, routing_keys: list[str]) -> None:
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
self.routing_keys = routing_keys
def get_routing_keys(self) -> list[str]:
return self.routing_keys
pass
async def shutdown(self) -> None:
pass

View file

@ -93,12 +93,9 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
pass
async def create_memory_bank(
self,

View file

@ -161,12 +161,9 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
pass
async def create_memory_bank(
self,

View file

@ -45,16 +45,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:

View file

@ -35,14 +35,10 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SAFETY_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None

View file

@ -53,9 +53,13 @@ class RoutingTable(Protocol):
class RoutableProvider(Protocol):
async def register_routing_keys(self, keys: List[str]) -> None: ...
"""
A provider which sits behind the RoutingTable and can get routed to.
def get_routing_keys(self) -> List[str]: ...
All Inference / Safety / Memory providers fall into this bucket.
"""
async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel):

View file

@ -38,15 +38,12 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
assert (
len(routing_keys) == 1
), f"Only one routing key is supported {routing_keys}"
assert routing_keys[0] == self.config.model
def get_routing_keys(self) -> List[str]:
return [self.config.model]
async def shutdown(self) -> None:
self.generator.stop()

View file

@ -72,12 +72,9 @@ class FaissMemoryImpl(Memory, RoutableProvider):
async def shutdown(self) -> None: ...
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
pass
async def create_memory_bank(
self,

View file

@ -51,15 +51,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def register_routing_keys(self, routing_keys: List[str]) -> None:
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
available_shields = [v.value for v in MetaReferenceShieldType]
for key in routing_keys:
if key not in available_shields:
raise ValueError(f"Unknown safety shield type: {key}")
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
async def run_shield(
self,

View file

@ -16,16 +16,12 @@ class RoutableProviderForModels(RoutableProvider):
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
async def register_routing_keys(self, routing_keys: List[str]):
async def validate_routing_keys(self, routing_keys: List[str]):
for routing_key in routing_keys:
if routing_key not in self.stack_to_provider_models_map:
raise ValueError(
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
)
self.routing_keys = routing_keys
def get_routing_keys(self) -> List[str]:
return self.routing_keys
def map_to_provider_model(self, routing_key: str) -> str:
model = resolve_model(routing_key)