diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index e0b778345..549b1866c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -40,6 +40,10 @@ class CommonRoutingTableImpl(RoutingTable): async def initialize(self) -> None: for keys, p in self.unique_providers: + spec = p.__provider_spec__ + if isinstance(spec, RemoteProviderSpec) and spec.adapter is None: + continue + await p.register_routing_keys(keys) async def shutdown(self) -> None: diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index 0a5f5bcd6..c741a61e3 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -13,7 +13,7 @@ import chromadb from numpy.typing import NDArray from llama_stack.apis.memory import * # noqa: F403 - +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, @@ -65,7 +65,7 @@ class ChromaIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class ChromaMemoryAdapter(Memory): +class ChromaMemoryAdapter(Memory, RoutableProvider): def __init__(self, url: str) -> None: print(f"Initializing ChromaMemoryAdapter with url: {url}") url = url.rstrip("/") @@ -93,6 +93,13 @@ class ChromaMemoryAdapter(Memory): async def shutdown(self) -> None: pass + async def register_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[chroma] Registering memory bank routing keys: {routing_keys}") + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index 9cf0771ab..5b57b166a 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -5,16 +5,17 @@ # the root directory of this source tree. import uuid - from typing import List, Tuple import psycopg2 from numpy.typing import NDArray from psycopg2 import sql from psycopg2.extras import execute_values, Json -from pydantic import BaseModel -from llama_stack.apis.memory import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.providers.utils.memory.vector_store import ( ALL_MINILM_L6_V2_DIMENSION, @@ -118,7 +119,7 @@ class PGVectorIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class PGVectorMemoryAdapter(Memory): +class PGVectorMemoryAdapter(Memory, RoutableProvider): def __init__(self, config: PGVectorConfig) -> None: print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}") self.config = config @@ -160,6 +161,13 @@ class PGVectorMemoryAdapter(Memory): async def shutdown(self) -> None: pass + async def register_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[pgvector] Registering memory bank routing keys: {routing_keys}") + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index a3acda1ce..d3eecc9c7 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -4,47 +4,63 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +import logging import traceback from typing import Any, Dict, List -from .config import BedrockSafetyConfig +import boto3 + from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 -import json -import logging +from llama_stack.distribution.datatypes import RoutableProvider -import boto3 +from .config import BedrockSafetyConfig logger = logging.getLogger(__name__) -class BedrockSafetyAdapter(Safety): +SUPPORTED_SHIELD_TYPES = [ + "bedrock_guardrail", +] + + +class BedrockSafetyAdapter(Safety, RoutableProvider): def __init__(self, config: BedrockSafetyConfig) -> None: + if not config.aws_profile: + raise ValueError(f"Missing boto_client aws_profile in model info::{config}") self.config = config async def initialize(self) -> None: - if not self.config.aws_profile: - raise RuntimeError( - f"Missing boto_client aws_profile in model info::{self.config}" - ) - try: - print(f"initializing with profile --- > {self.config}::") - self.boto_client_profile = self.config.aws_profile + print(f"initializing with profile --- > {self.config}") self.boto_client = boto3.Session( - profile_name=self.boto_client_profile + profile_name=self.config.aws_profile ).client("bedrock-runtime") except Exception as e: - raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e + raise RuntimeError("Error initializing BedrockSafetyAdapter") from e async def shutdown(self) -> None: pass + async def register_routing_keys(self, routing_keys: List[str]) -> None: + for key in routing_keys: + if key not in SUPPORTED_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {key}") + + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: + if shield_type not in SUPPORTED_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {shield_type}") + """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format ```content = [ { diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index 24fcc63b1..cb1040d19 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.sku_list import resolve_model from together import Together from llama_models.llama3.api.datatypes import * # noqa: F403 @@ -13,43 +12,43 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.distribution.request_headers import NeedsRequestProviderData from .config import TogetherSafetyConfig + SAFETY_SHIELD_TYPES = { + "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", } -def shield_type_to_model_name(shield_type: str) -> str: - if shield_type == "llama_guard": - shield_type = "Llama-Guard-3-8B" - - model = resolve_model(shield_type) - if ( - model is None - or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES - or model.model_family is not ModelFamily.safety - ): - raise ValueError( - f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}" - ) - - return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True)) - - -class TogetherSafetyImpl(Safety, NeedsRequestProviderData): +class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config async def initialize(self) -> None: pass + async def shutdown(self) -> None: + pass + + async def register_routing_keys(self, routing_keys: List[str]) -> None: + for key in routing_keys: + if key not in SAFETY_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {key}") + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: + if shield_type not in SAFETY_SHIELD_TYPES: + raise ValueError(f"Unknown safety shield type: {shield_type}") together_api_key = None provider_data = self.get_request_provider_data() @@ -59,7 +58,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData): ) together_api_key = provider_data.together_api_key - model_name = shield_type_to_model_name(shield_type) + model_name = SAFETY_SHIELD_TYPES[shield_type] # messages can have role assistant or user api_messages = [] diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 30b7245e6..d79ef7b6f 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -14,6 +14,7 @@ import numpy as np from numpy.typing import NDArray from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.distribution.datatypes import RoutableProvider from llama_stack.apis.memory import * # noqa: F403 from llama_stack.providers.utils.memory.vector_store import ( @@ -62,7 +63,7 @@ class FaissIndex(EmbeddingIndex): return QueryDocumentsResponse(chunks=chunks, scores=scores) -class FaissMemoryImpl(Memory): +class FaissMemoryImpl(Memory, RoutableProvider): def __init__(self, config: FaissImplConfig) -> None: self.config = config self.cache = {} @@ -71,6 +72,13 @@ class FaissMemoryImpl(Memory): async def shutdown(self) -> None: ... + async def register_routing_keys(self, routing_keys: List[str]) -> None: + print(f"[faiss] Registering memory bank routing keys: {routing_keys}") + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def create_memory_bank( self, name: str, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 6bb851596..a2ce69880 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -4,13 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, Dict, List + from llama_models.sku_list import resolve_model from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import Api, RoutableProvider from llama_stack.providers.impls.meta_reference.safety.shields.base import ( OnViolationAction, @@ -35,7 +37,7 @@ def resolve_and_get_path(model_name: str) -> str: return model_dir -class MetaReferenceSafetyImpl(Safety): +class MetaReferenceSafetyImpl(Safety, RoutableProvider): def __init__(self, config: SafetyConfig, deps) -> None: self.config = config self.inference_api = deps[Api.inference] @@ -46,6 +48,19 @@ class MetaReferenceSafetyImpl(Safety): model_dir = resolve_and_get_path(shield_cfg.model) _ = PromptGuardShield.instance(model_dir) + async def shutdown(self) -> None: + pass + + async def register_routing_keys(self, routing_keys: List[str]) -> None: + available_shields = [v.value for v in MetaReferenceShieldType] + for key in routing_keys: + if key not in available_shields: + raise ValueError(f"Unknown safety shield type: {key}") + self.routing_keys = routing_keys + + def get_routing_keys(self) -> List[str]: + return self.routing_keys + async def run_shield( self, shield_type: str,