diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 549b1866c..02dc942e8 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -44,7 +44,7 @@ class CommonRoutingTableImpl(RoutingTable): if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: continue - await p.register_routing_keys(keys) + await p.validate_routing_keys(keys) async def shutdown(self) -> None: for _, p in self.unique_providers: diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 67cfa21b5..a5e5a99be 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -34,13 +34,10 @@ class _HfAdapter(Inference, RoutableProvider): self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) - async def register_routing_keys(self, routing_keys: list[str]) -> None: + async def validate_routing_keys(self, routing_keys: list[str]) -> None: # these are the model names the Llama Stack will use to route requests to this provider # perform validation here if necessary - self.routing_keys = routing_keys - - def get_routing_keys(self) -> list[str]: - return self.routing_keys + pass async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index c741a61e3..afa13111f 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -93,12 +93,9 @@ class ChromaMemoryAdapter(Memory, RoutableProvider): async def shutdown(self) -> None: pass - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: print(f"[chroma] Registering memory bank routing keys: {routing_keys}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys + pass async def create_memory_bank( self, diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index 5b57b166a..5864aa7dc 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -161,12 +161,9 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider): async def shutdown(self) -> None: pass - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: print(f"[pgvector] Registering memory bank routing keys: {routing_keys}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys + pass async def create_memory_bank( self, diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index d3eecc9c7..814704e2c 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -45,16 +45,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider): async def shutdown(self) -> None: pass - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: for key in routing_keys: if key not in SUPPORTED_SHIELD_TYPES: raise ValueError(f"Unknown safety shield type: {key}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys - async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index 06b16d23d..c7a667e01 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -35,14 +35,10 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): async def shutdown(self) -> None: pass - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: for key in routing_keys: if key not in SAFETY_SHIELD_TYPES: raise ValueError(f"Unknown safety shield type: {key}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 12f107c41..a9a3d86e9 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -53,9 +53,13 @@ class RoutingTable(Protocol): class RoutableProvider(Protocol): - async def register_routing_keys(self, keys: List[str]) -> None: ... + """ + A provider which sits behind the RoutingTable and can get routed to. - def get_routing_keys(self) -> List[str]: ... + All Inference / Safety / Memory providers fall into this bucket. + """ + + async def validate_routing_keys(self, keys: List[str]) -> None: ... class GenericProviderConfig(BaseModel): diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index 5184b50f0..e89d8ec4c 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -38,15 +38,12 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider): self.generator = LlamaModelParallelGenerator(self.config) self.generator.start() - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: assert ( len(routing_keys) == 1 ), f"Only one routing key is supported {routing_keys}" assert routing_keys[0] == self.config.model - def get_routing_keys(self) -> List[str]: - return [self.config.model] - async def shutdown(self) -> None: self.generator.stop() diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index d79ef7b6f..b9a00908e 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -72,12 +72,9 @@ class FaissMemoryImpl(Memory, RoutableProvider): async def shutdown(self) -> None: ... - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: print(f"[faiss] Registering memory bank routing keys: {routing_keys}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys + pass async def create_memory_bank( self, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index a2ce69880..f02574f19 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -51,15 +51,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): async def shutdown(self) -> None: pass - async def register_routing_keys(self, routing_keys: List[str]) -> None: + async def validate_routing_keys(self, routing_keys: List[str]) -> None: available_shields = [v.value for v in MetaReferenceShieldType] for key in routing_keys: if key not in available_shields: raise ValueError(f"Unknown safety shield type: {key}") - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys async def run_shield( self, diff --git a/llama_stack/providers/utils/inference/routable.py b/llama_stack/providers/utils/inference/routable.py index 6dd2dd1fe..a36631208 100644 --- a/llama_stack/providers/utils/inference/routable.py +++ b/llama_stack/providers/utils/inference/routable.py @@ -16,16 +16,12 @@ class RoutableProviderForModels(RoutableProvider): def __init__(self, stack_to_provider_models_map: Dict[str, str]): self.stack_to_provider_models_map = stack_to_provider_models_map - async def register_routing_keys(self, routing_keys: List[str]): + async def validate_routing_keys(self, routing_keys: List[str]): for routing_key in routing_keys: if routing_key not in self.stack_to_provider_models_map: raise ValueError( f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}" ) - self.routing_keys = routing_keys - - def get_routing_keys(self) -> List[str]: - return self.routing_keys def map_to_provider_model(self, routing_key: str) -> str: model = resolve_model(routing_key)