mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Kill get_routing_keys(), rename register_ to validate_
This commit is contained in:
parent
aab81cd5ad
commit
1702aa5e3f
11 changed files with 20 additions and 48 deletions
|
@ -44,7 +44,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await p.register_routing_keys(keys)
|
await p.validate_routing_keys(keys)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for _, p in self.unique_providers:
|
for _, p in self.unique_providers:
|
||||||
|
|
|
@ -34,13 +34,10 @@ class _HfAdapter(Inference, RoutableProvider):
|
||||||
self.tokenizer = Tokenizer.get_instance()
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
self.formatter = ChatFormat(self.tokenizer)
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: list[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||||
# these are the model names the Llama Stack will use to route requests to this provider
|
# these are the model names the Llama Stack will use to route requests to this provider
|
||||||
# perform validation here if necessary
|
# perform validation here if necessary
|
||||||
self.routing_keys = routing_keys
|
pass
|
||||||
|
|
||||||
def get_routing_keys(self) -> list[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -93,12 +93,9 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
|
||||||
self.routing_keys = routing_keys
|
pass
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -161,12 +161,9 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
|
||||||
self.routing_keys = routing_keys
|
pass
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -45,16 +45,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
for key in routing_keys:
|
for key in routing_keys:
|
||||||
if key not in SUPPORTED_SHIELD_TYPES:
|
if key not in SUPPORTED_SHIELD_TYPES:
|
||||||
raise ValueError(f"Unknown safety shield type: {key}")
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
|
|
||||||
self.routing_keys = routing_keys
|
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -35,14 +35,10 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
for key in routing_keys:
|
for key in routing_keys:
|
||||||
if key not in SAFETY_SHIELD_TYPES:
|
if key not in SAFETY_SHIELD_TYPES:
|
||||||
raise ValueError(f"Unknown safety shield type: {key}")
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
self.routing_keys = routing_keys
|
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
|
|
|
@ -53,9 +53,13 @@ class RoutingTable(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class RoutableProvider(Protocol):
|
class RoutableProvider(Protocol):
|
||||||
async def register_routing_keys(self, keys: List[str]) -> None: ...
|
"""
|
||||||
|
A provider which sits behind the RoutingTable and can get routed to.
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]: ...
|
All Inference / Safety / Memory providers fall into this bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class GenericProviderConfig(BaseModel):
|
class GenericProviderConfig(BaseModel):
|
||||||
|
|
|
@ -38,15 +38,12 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
|
||||||
self.generator = LlamaModelParallelGenerator(self.config)
|
self.generator = LlamaModelParallelGenerator(self.config)
|
||||||
self.generator.start()
|
self.generator.start()
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
assert (
|
assert (
|
||||||
len(routing_keys) == 1
|
len(routing_keys) == 1
|
||||||
), f"Only one routing key is supported {routing_keys}"
|
), f"Only one routing key is supported {routing_keys}"
|
||||||
assert routing_keys[0] == self.config.model
|
assert routing_keys[0] == self.config.model
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return [self.config.model]
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.generator.stop()
|
self.generator.stop()
|
||||||
|
|
||||||
|
|
|
@ -72,12 +72,9 @@ class FaissMemoryImpl(Memory, RoutableProvider):
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
print(f"[faiss] Registering memory bank routing keys: {routing_keys}")
|
||||||
self.routing_keys = routing_keys
|
pass
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def create_memory_bank(
|
async def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -51,15 +51,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]) -> None:
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
available_shields = [v.value for v in MetaReferenceShieldType]
|
available_shields = [v.value for v in MetaReferenceShieldType]
|
||||||
for key in routing_keys:
|
for key in routing_keys:
|
||||||
if key not in available_shields:
|
if key not in available_shields:
|
||||||
raise ValueError(f"Unknown safety shield type: {key}")
|
raise ValueError(f"Unknown safety shield type: {key}")
|
||||||
self.routing_keys = routing_keys
|
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -16,16 +16,12 @@ class RoutableProviderForModels(RoutableProvider):
|
||||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||||
|
|
||||||
async def register_routing_keys(self, routing_keys: List[str]):
|
async def validate_routing_keys(self, routing_keys: List[str]):
|
||||||
for routing_key in routing_keys:
|
for routing_key in routing_keys:
|
||||||
if routing_key not in self.stack_to_provider_models_map:
|
if routing_key not in self.stack_to_provider_models_map:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
|
f"Routing key {routing_key} not found in map {self.stack_to_provider_models_map}"
|
||||||
)
|
)
|
||||||
self.routing_keys = routing_keys
|
|
||||||
|
|
||||||
def get_routing_keys(self) -> List[str]:
|
|
||||||
return self.routing_keys
|
|
||||||
|
|
||||||
def map_to_provider_model(self, routing_key: str) -> str:
|
def map_to_provider_model(self, routing_key: str) -> str:
|
||||||
model = resolve_model(routing_key)
|
model = resolve_model(routing_key)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue