mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Kill get_routing_keys(), rename register_ to validate_
This commit is contained in:
parent
aab81cd5ad
commit
1702aa5e3f
11 changed files with 20 additions and 48 deletions
|
@ -44,7 +44,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||
continue
|
||||
|
||||
await p.register_routing_keys(keys)
|
||||
await p.validate_routing_keys(keys)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for _, p in self.unique_providers:
|
||||
|
|
|
@ -34,13 +34,10 @@ class _HfAdapter(Inference, RoutableProvider):
|
|||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def register_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> list[str]:
|
||||
return self.routing_keys
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
|
@ -93,12 +93,9 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
|
|
@ -161,12 +161,9 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
|
|
@ -45,16 +45,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
|
|
|
@ -35,14 +35,10 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
|
|
|
@ -53,9 +53,13 @@ class RoutingTable(Protocol):
|
|||
|
||||
|
||||
class RoutableProvider(Protocol):
|
||||
async def register_routing_keys(self, keys: List[str]) -> None: ...
|
||||
"""
|
||||
A provider which sits behind the RoutingTable and can get routed to.
|
||||
|
||||
def get_routing_keys(self) -> List[str]: ...
|
||||
All Inference / Safety / Memory providers fall into this bucket.
|
||||
"""
|
||||
|
||||
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
|
|
|
@ -38,15 +38,12 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
|||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
assert (
|
||||
len(routing_keys) == 1
|
||||
), f"Only one routing key is supported {routing_keys}"
|
||||
assert routing_keys[0] == self.config.model
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return [self.config.model]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
|
|
|
@ -72,12 +72,9 @@ class FaissMemoryImpl(Memory, RoutableProvider):
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
pass
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
|
|
@ -51,15 +51,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||
for key in routing_keys:
|
||||
if key not in available_shields:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
@ -16,16 +16,12 @@ class RoutableProviderForModels(RoutableProvider):
|
|||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
||||
async def register_routing_keys(self, routing_keys: List[str]):
|
||||
async def validate_routing_keys(self, routing_keys: List[str]):
|
||||
for routing_key in routing_keys:
|
||||
if routing_key not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
self.routing_keys = routing_keys
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
def map_to_provider_model(self, routing_key: str) -> str:
|
||||
model = resolve_model(routing_key)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue