From 4215cc9331b1daff089241d14c22244dec81ef07 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 5 Oct 2024 22:17:06 -0700 Subject: [PATCH] Push registration methods onto the backing providers --- llama_stack/apis/agents/agents.py | 2 +- llama_stack/apis/inference/inference.py | 10 + llama_stack/apis/memory/memory.py | 10 + llama_stack/apis/safety/safety.py | 10 + llama_stack/distribution/datatypes.py | 17 ++ llama_stack/distribution/resolver.py | 194 +++++++++--------- llama_stack/distribution/routers/__init__.py | 2 - .../distribution/routers/routing_tables.py | 63 +++--- .../adapters/safety/together/together.py | 46 +++-- llama_stack/providers/datatypes.py | 10 - .../impls/meta_reference/memory/faiss.py | 36 ++-- .../impls/meta_reference/safety/config.py | 8 +- .../impls/meta_reference/safety/safety.py | 79 ++++--- .../providers/utils/memory/vector_store.py | 2 +- 14 files changed, 269 insertions(+), 220 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index d008331d5..f9ad44efc 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -261,7 +261,7 @@ class Session(BaseModel): turns: List[Turn] started_at: datetime - memory_bank: Optional[MemoryBank] = None + memory_bank: Optional[MemoryBankDef] = None class AgentConfigCommon(BaseModel): diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 428f29b88..5374f2efb 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.models import * # noqa: F403 class LogProbConfig(BaseModel): @@ -203,3 +204,12 @@ class Inference(Protocol): model: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: ... + + @webmethod(route="/inference/register_model") + async def register_model(self, model: ModelDef) -> None: ... + + @webmethod(route="/inference/list_models") + async def list_models(self) -> List[ModelDef]: ... + + @webmethod(route="/inference/get_model") + async def get_model(self, identifier: str) -> Optional[ModelDef]: ... diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 8ac4a08a6..86dcbbcdc 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -15,6 +15,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.memory_banks import * # noqa: F403 @json_schema_type @@ -76,3 +77,12 @@ class Memory(Protocol): bank_id: str, document_ids: List[str], ) -> None: ... + + @webmethod(route="/memory/register_memory_bank") + async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... + + @webmethod(route="/memory/list_memory_banks") + async def list_memory_banks(self) -> List[MemoryBankDef]: ... + + @webmethod(route="/memory/get_memory_bank") + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index ed3a42f66..a3c94d136 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -11,6 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.shields import * # noqa: F403 @json_schema_type @@ -42,3 +43,12 @@ class Safety(Protocol): async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... + + @webmethod(route="/safety/register_shield") + async def register_shield(self, shield: ShieldDef) -> None: ... + + @webmethod(route="/safety/list_shields") + async def list_shields(self) -> List[ShieldDef]: ... + + @webmethod(route="/safety/get_shield") + async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 0ee03175c..05b2ad0d6 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -14,6 +14,9 @@ from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.inference import Inference +from llama_stack.apis.memory import Memory +from llama_stack.apis.safety import Safety LLAMA_STACK_BUILD_CONFIG_VERSION = "2" @@ -23,6 +26,19 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2" RoutingKey = Union[str, List[str]] +RoutableObject = Union[ + ModelDef, + ShieldDef, + MemoryBankDef, +] + +RoutedProtocol = Union[ + Inference, + Safety, + Memory, +] + + class GenericProviderConfig(BaseModel): provider_type: str config: Dict[str, Any] @@ -56,6 +72,7 @@ class RoutingTableProviderSpec(ProviderSpec): docker_image: Optional[str] = None router_api: Api + registry: List[RoutableObject] module: str pip_packages: List[str] = Field(default_factory=list) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index ec8374290..660d84fc8 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -28,46 +28,48 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An """ all_api_providers = get_provider_registry() - auto_routed_apis = builtin_automatically_routed_apis() + routing_table_apis = set( + x.routing_table_api for x in builtin_automatically_routed_apis() + ) + router_apis = set(x.router_api for x in builtin_automatically_routed_apis()) + providers_with_specs = {} - for api_str, instances in run_config.providers.items(): + for api_str, providers in run_config.providers.items(): api = Api(api_str) - if api in [a.routing_table_api for a in auto_routed_apis]: + if api in routing_table_apis: raise ValueError( f"Provider for `{api_str}` is automatically provided and cannot be overridden" ) - providers_with_specs[api] = {} - for config in instances: - if config.provider_type not in all_api_providers[api]: + specs = {} + for provider in providers: + if provider.provider_type not in all_api_providers[api]: raise ValueError( - f"Provider `{config.provider_type}` is not available for API `{api}`" + f"Provider `{provider.provider_type}` is not available for API `{api}`" ) spec = ProviderWithSpec( - spec=all_api_providers[api][config.provider_type], - **config, + spec=all_api_providers[api][provider.provider_type], + **(provider.dict()), ) - providers_with_specs[api][spec.provider_id] = spec + specs[provider.provider_id] = spec + + key = api_str if api not in router_apis else f"inner-{api_str}" + providers_with_specs[key] = specs apis_to_serve = run_config.apis_to_serve or set( - list(providers_with_specs.keys()) - + [a.routing_table_api.value for a in auto_routed_apis] + list(providers_with_specs.keys()) + list(routing_table_apis) ) + for info in builtin_automatically_routed_apis(): if info.router_api.value not in apis_to_serve: continue - if info.routing_table_api.value not in run_config: - raise ValueError( - f"Registry for `{info.routing_table_api.value}` is not provided?" - ) - - available_providers = providers_with_specs[info.router_api] + available_providers = providers_with_specs[f"inner-{info.router_api.value}"] inner_deps = [] - registry = run_config[info.routing_table_api.value] + registry = getattr(run_config, info.routing_table_api.value) for entry in registry: if entry.provider_id not in available_providers: raise ValueError( @@ -77,74 +79,70 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An provider = available_providers[entry.provider_id] inner_deps.extend(provider.spec.api_dependencies) - providers_with_specs[info.routing_table_api] = { - "__builtin__": [ - ProviderWithSpec( - provider_id="__builtin__", - provider_type="__builtin__", - config=registry, - spec=RoutingTableProviderSpec( - api=info.routing_table_api, - router_api=info.router_api, - module="llama_stack.distribution.routers", - api_dependencies=inner_deps, - ), - ) - ] + providers_with_specs[info.routing_table_api.value] = { + "__builtin__": ProviderWithSpec( + provider_id="__builtin__", + provider_type="__routing_table__", + config={}, + spec=RoutingTableProviderSpec( + api=info.routing_table_api, + router_api=info.router_api, + registry=registry, + module="llama_stack.distribution.routers", + api_dependencies=inner_deps, + ), + ) } - providers_with_specs[info.router_api] = { - "__builtin__": [ - ProviderWithSpec( - provider_id="__builtin__", - provider_type="__builtin__", - config={}, - spec=AutoRoutedProviderSpec( - api=info.router_api, - module="llama_stack.distribution.routers", - routing_table_api=source_api, - api_dependencies=[source_api], - ), - ) - ] + providers_with_specs[info.router_api.value] = { + "__builtin__": ProviderWithSpec( + provider_id="__builtin__", + provider_type="__autorouted__", + config={}, + spec=AutoRoutedProviderSpec( + api=info.router_api, + module="llama_stack.distribution.routers", + routing_table_api=info.routing_table_api, + api_dependencies=[info.routing_table_api], + ), + ) } - sorted_providers = topological_sort(providers_with_specs) + sorted_providers = topological_sort( + {k: v.values() for k, v in providers_with_specs.items()} + ) sorted_providers.append( - ProviderWithSpec( - provider_id="__builtin__", - provider_type="__builtin__", - config={}, - spec=InlineProviderSpec( - api=Api.inspect, + ( + "inspect", + ProviderWithSpec( + provider_id="__builtin__", provider_type="__builtin__", - config_class="llama_stack.distribution.inspect.DistributionInspectConfig", - module="llama_stack.distribution.inspect", + config={}, + spec=InlineProviderSpec( + api=Api.inspect, + provider_type="__builtin__", + config_class="llama_stack.distribution.inspect.DistributionInspectConfig", + module="llama_stack.distribution.inspect", + ), ), ) ) print(f"Resolved {len(sorted_providers)} providers in topological order") - for provider in sorted_providers: - print( - f" {provider.spec.api}: ({provider.provider_id}) {provider.spec.provider_type}" - ) + for api_str, provider in sorted_providers: + print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}") print("") + impls = {} - - impls_by_provider_id = {} - for provider in sorted_providers: - api = provider.spec.api - if api not in impls_by_provider_id: - impls_by_provider_id[api] = {} - - deps = {api: impls[api] for api in provider.spec.api_dependencies} + inner_impls_by_provider_id = {f"inner-{x}": {} for x in router_apis} + for api_str, provider in sorted_providers: + deps = {a: impls[a] for a in provider.spec.api_dependencies} inner_impls = {} if isinstance(provider.spec, RoutingTableProviderSpec): - for entry in provider.config: - inner_impls[entry.provider_id] = impls_by_provider_id[ - provider.spec.router_api + for entry in provider.spec.registry: + inner_impls[entry.provider_id] = inner_impls_by_provider_id[ + f"inner-{provider.spec.router_api.value}" ][entry.provider_id] impl = await instantiate_provider( @@ -152,37 +150,46 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An deps, inner_impls, ) - - impls[api] = impl - impls_by_provider_id[api][provider.provider_id] = impl + if "inner-" in api_str: + inner_impls_by_provider_id[api_str][provider.provider_id] = impl + else: + api = Api(api_str) + impls[api] = impl return impls def topological_sort( - providers_with_specs: Dict[Api, List[ProviderWithSpec]], + providers_with_specs: Dict[str, List[ProviderWithSpec]], ) -> List[ProviderWithSpec]: - def dfs(kv, visited: Set[Api], stack: List[Api]): - api, providers = kv - visited.add(api) + def dfs(kv, visited: Set[str], stack: List[str]): + api_str, providers = kv + visited.add(api_str) - deps = [dep for x in providers for dep in x.api_dependencies] - for api in deps: - if api not in visited: - dfs((api, providers_with_specs[api]), visited, stack) + deps = [] + for provider in providers: + for dep in provider.spec.api_dependencies: + deps.append(dep.value) + if isinstance(provider, AutoRoutedProviderSpec): + deps.append(f"inner-{provider.api}") - stack.append(api) + for dep in deps: + if dep not in visited: + dfs((dep, providers_with_specs[dep]), visited, stack) + + stack.append(api_str) visited = set() stack = [] - for api, providers in providers_with_specs.items(): - if api not in visited: - dfs((api, providers), visited, stack) + for api_str, providers in providers_with_specs.items(): + if api_str not in visited: + dfs((api_str, providers), visited, stack) flattened = [] - for api in stack: - flattened.extend(providers_with_specs[api]) + for api_str in stack: + for provider in providers_with_specs[api_str]: + flattened.append((api_str, provider)) return flattened @@ -202,9 +209,8 @@ async def instantiate_provider( else: method = "get_client_impl" - assert isinstance(provider_config, GenericProviderConfig) config_type = instantiate_class_type(provider_spec.config_class) - config = config_type(**provider_config.config) + config = config_type(**provider.config) args = [config, deps] elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" @@ -214,17 +220,13 @@ async def instantiate_provider( elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" - assert isinstance(provider_config, list) - registry = provider_config - config = None - args = [provider_spec.api, registry, inner_impls, deps] + args = [provider_spec.api, provider_spec.registry, inner_impls, deps] else: method = "get_provider_impl" - assert isinstance(provider_config, GenericProviderConfig) config_type = instantiate_class_type(provider_spec.config_class) - config = config_type(**provider_config.config) + config = config_type(**provider.config) args = [config, deps] fn = getattr(module, method) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 0464ab57a..9935ecd7d 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -10,8 +10,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 from .routing_tables import ( MemoryBanksRoutingTable, ModelsRoutingTable, - RoutableObject, - RoutedProtocol, ShieldsRoutingTable, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 01d92ff12..fbc3eae32 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -4,33 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, List, Optional, Union +from typing import Any, List, Optional from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.inference import Inference -from llama_stack.apis.memory import Memory -from llama_stack.apis.safety import Safety from llama_stack.distribution.datatypes import * # noqa: F403 -RoutableObject = Union[ - ModelDef, - ShieldDef, - MemoryBankDef, -] - -RoutedProtocol = Union[ - Inference, - Safety, - Memory, -] - - class CommonRoutingTableImpl(RoutingTable): def __init__( self, @@ -46,19 +30,14 @@ class CommonRoutingTableImpl(RoutingTable): self.impls_by_provider_id = impls_by_provider_id self.registry = registry - async def initialize(self) -> None: - keys_by_provider = {} + self.routing_key_to_object = {} for obj in self.registry: - keys = keys_by_provider.setdefault(obj.provider_id, []) - keys.append(obj.routing_key) + self.routing_key_to_object[obj.identifier] = obj - for provider_id, keys in keys_by_provider.items(): - p = self.impls_by_provider_id[provider_id] - spec = p.__provider_spec__ - if is_passthrough(spec): - continue - - await p.validate_routing_keys(keys) + async def initialize(self) -> None: + for obj in self.registry: + p = self.impls_by_provider_id[obj.provider_id] + await self.register_object(obj, p) async def shutdown(self) -> None: pass @@ -75,8 +54,24 @@ class CommonRoutingTableImpl(RoutingTable): return obj return None + def register_object(self, obj: RoutableObject) -> None: + if obj.identifier in self.routing_key_to_object: + raise ValueError(f"Object `{obj.identifier}` already registered") + + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + + p = self.impls_by_provider_id[obj.provider_id] + await p.register_object(obj) + + self.routing_key_to_object[obj.identifier] = obj + self.registry.append(obj) + class ModelsRoutingTable(CommonRoutingTableImpl, Models): + async def register_object(self, obj: ModelDef, p: Inference) -> None: + await p.register_model(obj) + async def list_models(self) -> List[ModelDef]: return self.registry @@ -84,10 +79,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return self.get_object_by_identifier(identifier) async def register_model(self, model: ModelDef) -> None: - raise NotImplementedError() + await self.register_object(model) class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def register_object(self, obj: ShieldDef, p: Safety) -> None: + await p.register_shield(obj) + async def list_shields(self) -> List[ShieldDef]: return self.registry @@ -95,10 +93,13 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return self.get_object_by_identifier(shield_type) async def register_shield(self, shield: ShieldDef) -> None: - raise NotImplementedError() + await self.register_object(shield) class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): + async def register_object(self, obj: MemoryBankDef, p: Memory) -> None: + await p.register_memory_bank(obj) + async def list_memory_banks(self) -> List[MemoryBankDef]: return self.registry @@ -106,4 +107,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return self.get_object_by_identifier(identifier) async def register_memory_bank(self, bank: MemoryBankDef) -> None: - raise NotImplementedError() + await self.register_object(bank) diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index c7a667e01..9d9fa6a4e 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -6,28 +6,23 @@ from together import Together from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.safety import ( - RunShieldResponse, - Safety, - SafetyViolation, - ViolationLevel, -) -from llama_stack.distribution.datatypes import RoutableProvider +from llama_stack.apis.safety import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData from .config import TogetherSafetyConfig -SAFETY_SHIELD_TYPES = { +SAFETY_SHIELD_MODEL_MAP = { "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", } -class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): +class TogetherSafetyImpl(Safety, NeedsRequestProviderData): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config + self.register_shields = [] async def initialize(self) -> None: pass @@ -35,16 +30,31 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): async def shutdown(self) -> None: pass - async def validate_routing_keys(self, routing_keys: List[str]) -> None: - for key in routing_keys: - if key not in SAFETY_SHIELD_TYPES: - raise ValueError(f"Unknown safety shield type: {key}") + async def register_shield(self, shield: ShieldDef) -> None: + if shield.type != ShieldType.llama_guard.value: + raise ValueError(f"Unsupported safety shield type: {shield.type}") + + self.registered_shields.append(shield) + + async def list_shields(self) -> List[ShieldDef]: + return self.registered_shields + + async def get_shield(self, identifier: str) -> Optional[ShieldDef]: + for shield in self.registered_shields: + if shield.identifier == identifier: + return shield + return None async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - if shield_type not in SAFETY_SHIELD_TYPES: - raise ValueError(f"Unknown safety shield type: {shield_type}") + shield_def = await self.get_shield(shield_type) + if not shield_def: + raise ValueError(f"Unknown shield {shield_type}") + + model = shield_def.params.get("model", "llama_guard") + if model not in SAFETY_SHIELD_MODEL_MAP: + raise ValueError(f"Unsupported safety model: {model}") together_api_key = None if self.config.api_key is not None: @@ -57,17 +67,13 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): ) together_api_key = provider_data.together_api_key - model_name = SAFETY_SHIELD_TYPES[shield_type] - # messages can have role assistant or user api_messages = [] for message in messages: if message.role in (Role.user.value, Role.assistant.value): api_messages.append({"role": message.role, "content": message.content}) - violation = await get_safety_response( - together_api_key, model_name, api_messages - ) + violation = await get_safety_response(together_api_key, model, api_messages) return RunShieldResponse(violation=violation) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index abc1d601d..a254e2808 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -48,16 +48,6 @@ class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... -class RoutableProvider(Protocol): - """ - A provider which sits behind the RoutingTable and can get routed to. - - All Inference / Safety / Memory providers fall into this bucket. - """ - - async def validate_routing_keys(self, keys: List[str]) -> None: ... - - @json_schema_type class AdapterSpec(BaseModel): adapter_type: str = Field( diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index b9a00908e..4f592e5e0 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import logging -import uuid from typing import Any, Dict, List, Optional @@ -72,38 +71,29 @@ class FaissMemoryImpl(Memory, RoutableProvider): async def shutdown(self) -> None: ... - async def validate_routing_keys(self, routing_keys: List[str]) -> None: - print(f"[faiss] Registering memory bank routing keys: {routing_keys}") - pass - - async def create_memory_bank( + async def register_memory_bank( self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - assert url is None, "URL is not supported for this implementation" + memory_bank: MemoryBankDef, + ) -> None: assert ( - config.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {config.type}" + memory_bank.type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.type}" - bank_id = str(uuid.uuid4()) - bank = MemoryBank( - bank_id=bank_id, - name=name, - config=config, - url=url, + index = BankWithIndex( + bank=memory_bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) - index = BankWithIndex(bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)) - self.cache[bank_id] = index + self.cache[memory_bank.identifier] = index return bank - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - index = self.cache.get(bank_id) + async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: + index = self.cache.get(identifier) if index is None: return None return index.bank + async def list_memory_banks(self) -> List[MemoryBankDef]: + return [x.bank for x in self.cache.values()] + async def insert_documents( self, bank_id: str, diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py index 51d2ae2bf..14233ad0c 100644 --- a/llama_stack/providers/impls/meta_reference/safety/config.py +++ b/llama_stack/providers/impls/meta_reference/safety/config.py @@ -12,11 +12,9 @@ from llama_models.sku_list import CoreModelId, safety_models from pydantic import BaseModel, field_validator -class MetaReferenceShieldType(Enum): - llama_guard = "llama_guard" - code_scanner_guard = "code_scanner_guard" - injection_shield = "injection_shield" - jailbreak_shield = "jailbreak_shield" +class PromptGuardType(Enum): + injection = "injection" + jailbreak = "jailbreak" class LlamaGuardShieldConfig(BaseModel): diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index bf19a3010..5154acd77 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -10,23 +10,36 @@ from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import Api, RoutableProvider +from llama_stack.distribution.datatypes import Api from llama_stack.providers.impls.meta_reference.safety.shields.base import ( OnViolationAction, ) -from .config import MetaReferenceShieldType, SafetyConfig +from .config import SafetyConfig -from .shields import CodeScannerShield, LlamaGuardShield, ShieldBase +from .shields import ( + CodeScannerShield, + InjectionShield, + JailbreakShield, + LlamaGuardShield, + ShieldBase, +) PROMPT_GUARD_MODEL = "Prompt-Guard-86M" -class MetaReferenceSafetyImpl(Safety, RoutableProvider): +class MetaReferenceSafetyImpl(Safety): def __init__(self, config: SafetyConfig, deps) -> None: self.config = config self.inference_api = deps[Api.inference] + self.registered_shields = [] + + self.available_shields = [ShieldType.code_scanner.value] + if config.llama_guard_shield: + self.available_shields.append(ShieldType.llama_guard.value) + if config.enable_prompt_guard: + self.available_shields.append(ShieldType.prompt_guard.value) async def initialize(self) -> None: if self.config.enable_prompt_guard: @@ -38,11 +51,20 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): async def shutdown(self) -> None: pass - async def validate_routing_keys(self, routing_keys: List[str]) -> None: - available_shields = [v.value for v in MetaReferenceShieldType] - for key in routing_keys: - if key not in available_shields: - raise ValueError(f"Unknown safety shield type: {key}") + async def register_shield(self, shield: ShieldDef) -> None: + if shield.type not in self.available_shields: + raise ValueError(f"Unsupported safety shield type: {shield.type}") + + self.registered_shields.append(shield) + + async def list_shields(self) -> List[ShieldDef]: + return self.registered_shields + + async def get_shield(self, identifier: str) -> Optional[ShieldDef]: + for shield in self.registered_shields: + if shield.identifier == identifier: + return shield + return None async def run_shield( self, @@ -50,10 +72,11 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: - available_shields = [v.value for v in MetaReferenceShieldType] - assert shield_type in available_shields, f"Unknown shield {shield_type}" + shield_def = await self.get_shield(shield_type) + if not shield_def: + raise ValueError(f"Unknown shield {shield_type}") - shield = self.get_shield_impl(MetaReferenceShieldType(shield_type)) + shield = self.get_shield_impl(shield_def) messages = messages.copy() # some shields like llama-guard require the first message to be a user message @@ -79,30 +102,24 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): return RunShieldResponse(violation=violation) - def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: - cfg = self.config - if typ == MetaReferenceShieldType.llama_guard: - cfg = cfg.llama_guard_shield - assert ( - cfg is not None - ), "Cannot use LlamaGuardShield since not present in config" - + def get_shield_impl(self, shield: ShieldDef) -> ShieldBase: + if shield.type == ShieldType.llama_guard.value: + cfg = self.config.llama_guard_shield return LlamaGuardShield( model=cfg.model, inference_api=self.inference_api, excluded_categories=cfg.excluded_categories, ) - elif typ == MetaReferenceShieldType.jailbreak_shield: - from .shields import JailbreakShield - + elif shield.type == ShieldType.prompt_guard.value: model_dir = model_local_dir(PROMPT_GUARD_MODEL) - return JailbreakShield.instance(model_dir) - elif typ == MetaReferenceShieldType.injection_shield: - from .shields import InjectionShield - - model_dir = model_local_dir(PROMPT_GUARD_MODEL) - return InjectionShield.instance(model_dir) - elif typ == MetaReferenceShieldType.code_scanner_guard: + subtype = shield.params.get("prompt_guard_type", "injection") + if subtype == "injection": + return InjectionShield.instance(model_dir) + elif subtype == "jailbreak": + return JailbreakShield.instance(model_dir) + else: + raise ValueError(f"Unknown prompt guard type: {subtype}") + elif shield.type == ShieldType.code_scanner.value: return CodeScannerShield.instance() else: - raise ValueError(f"Unknown shield type: {typ}") + raise ValueError(f"Unknown shield type: {shield.type}") diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 1683ddaa1..0540cdf60 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -146,7 +146,7 @@ class EmbeddingIndex(ABC): @dataclass class BankWithIndex: - bank: MemoryBank + bank: MemoryBankDef index: EmbeddingIndex async def insert_documents(