diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 5374f2efb..7ff70a2af 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]
 
 
+class ModelStore(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
 class Inference(Protocol):
+    model_store: ModelStore
+
     @webmethod(route="/inference/completion")
     async def completion(
         self,
@@ -207,9 +213,3 @@ class Inference(Protocol):
 
     @webmethod(route="/inference/register_model")
     async def register_model(self, model: ModelDef) -> None: ...
-
-    @webmethod(route="/inference/list_models")
-    async def list_models(self) -> List[ModelDef]: ...
-
-    @webmethod(route="/inference/get_model")
-    async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py
index 86dcbbcdc..c5161e864 100644
--- a/llama_stack/apis/memory/memory.py
+++ b/llama_stack/apis/memory/memory.py
@@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
     scores: List[float]
 
 
+class MemoryBankStore(Protocol):
+    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
+
+
 class Memory(Protocol):
+    memory_bank_store: MemoryBankStore
+
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
     @webmethod(route="/memory/insert")
@@ -80,9 +86,3 @@ class Memory(Protocol):
 
     @webmethod(route="/memory/register_memory_bank")
     async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
-
-    @webmethod(route="/memory/list_memory_banks")
-    async def list_memory_banks(self) -> List[MemoryBankDef]: ...
-
-    @webmethod(route="/memory/get_memory_bank")
-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index a3c94d136..4f4a49407 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None
 
 
+class ShieldStore(Protocol):
+    def get_shield(self, identifier: str) -> ShieldDef: ...
+
+
 class Safety(Protocol):
+    shield_store: ShieldStore
+
     @webmethod(route="/safety/run_shield")
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
@@ -46,9 +52,3 @@ class Safety(Protocol):
 
     @webmethod(route="/safety/register_shield")
     async def register_shield(self, shield: ShieldDef) -> None: ...
-
-    @webmethod(route="/safety/list_shields")
-    async def list_shields(self) -> List[ShieldDef]: ...
-
-    @webmethod(route="/safety/get_shield")
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
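Taken together, these three API hunks swap per-provider `list_*`/`get_*` plumbing for a single store attribute that the routing layer injects. Below is a minimal, self-contained sketch of the consumption pattern; `InMemoryModelStore` and `EchoAdapter` are illustrative stand-ins, not part of the diff (in the stack itself the store is the async routing table, which is why call sites `await` it):

```python
import asyncio
from typing import Dict, Optional


class ModelDef:
    """Stand-in for llama_stack's ModelDef (identifier plus owning provider)."""

    def __init__(self, identifier: str, provider_id: str) -> None:
        self.identifier = identifier
        self.provider_id = provider_id


class InMemoryModelStore:
    """Illustrative; in the stack this role is played by ModelsRoutingTable."""

    def __init__(self) -> None:
        self._models: Dict[str, ModelDef] = {}

    def add(self, model: ModelDef) -> None:
        self._models[model.identifier] = model

    async def get_model(self, identifier: str) -> Optional[ModelDef]:
        return self._models.get(identifier)


class EchoAdapter:
    # assigned by the routing layer at startup (see routing_tables.py below)
    model_store: InMemoryModelStore

    async def chat_completion(self, model: str) -> str:
        model_def = await self.model_store.get_model(model)
        if model_def is None:
            raise ValueError(f"Unknown model: {model}")
        return f"would route {model!r} via provider {model_def.provider_id!r}"


async def main() -> None:
    store = InMemoryModelStore()
    store.add(ModelDef("Llama3.1-8B-Instruct", "provider-0"))
    adapter = EchoAdapter()
    adapter.model_store = store  # what CommonRoutingTableImpl does for real impls
    print(await adapter.chat_completion("Llama3.1-8B-Instruct"))


asyncio.run(main())
```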
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index c987d4c87..f08eec462 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -39,6 +39,14 @@ RoutedProtocol = Union[
 ]
 
 
+class ModelRegistry(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
+class MemoryBankRegistry(Protocol):
+    def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
+
+
 # Example: /inference, /safety
 class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index d0c3adb84..0adb42915 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -151,6 +151,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
             deps,
             inner_impls,
         )
+        # TODO: ugh slightly redesign this shady looking code
        if "inner-" in api_str:
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl
         else:
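These registry protocols are structural (`typing.Protocol`): any object exposing a matching `get_model` satisfies `ModelRegistry` without inheriting from it, which is what lets the routing table double as every provider's store. A tiny sketch (return type simplified to `str`; the class name is illustrative):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ModelRegistry(Protocol):
    def get_model(self, identifier: str) -> str: ...  # return type simplified


class AnythingWithGetModel:
    """Note: no inheritance from ModelRegistry anywhere."""

    def get_model(self, identifier: str) -> str:
        return identifier


# Protocols match on shape, not ancestry
assert isinstance(AnythingWithGetModel(), ModelRegistry)
```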
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index e51534446..ef38b6391 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -15,6 +15,20 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
+def get_impl_api(p: Any) -> Api:
+    return p.__provider_spec__.api
+
+
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.inference:
+        await p.register_model(obj)
+    elif api == Api.safety:
+        await p.register_shield(obj)
+    elif api == Api.memory:
+        await p.register_memory_bank(obj)
+
+
 # TODO: this routing table maintains state in memory purely. We need to
 # add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
@@ -32,6 +46,15 @@ class CommonRoutingTableImpl(RoutingTable):
         self.impls_by_provider_id = impls_by_provider_id
         self.registry = registry
 
+        for p in self.impls_by_provider_id.values():
+            api = get_impl_api(p)
+            if api == Api.inference:
+                p.model_store = self
+            elif api == Api.safety:
+                p.shield_store = self
+            elif api == Api.memory:
+                p.memory_bank_store = self
+
         self.routing_key_to_object = {}
         for obj in self.registry:
             self.routing_key_to_object[obj.identifier] = obj
@@ -39,7 +62,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         for obj in self.registry:
             p = self.impls_by_provider_id[obj.provider_id]
-            await self.register_object(obj, p)
+            await register_object_with_provider(obj, p)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -57,7 +80,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None
 
-    async def register_object_common(self, obj: RoutableObject) -> None:
+    async def register_object(self, obj: RoutableObject) -> Any:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")
 
@@ -65,16 +88,13 @@ class CommonRoutingTableImpl(RoutingTable):
             raise ValueError(f"Provider `{obj.provider_id}` not found")
 
         p = self.impls_by_provider_id[obj.provider_id]
-        await p.register_object(obj)
+        await register_object_with_provider(obj, p)
 
         self.routing_key_to_object[obj.identifier] = obj
         self.registry.append(obj)
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
-    async def register_object(self, obj: ModelDef, p: Inference) -> None:
-        await p.register_model(obj)
-
     async def list_models(self) -> List[ModelDef]:
         return self.registry
 
@@ -82,13 +102,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)
 
     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object_common(model)
+        await self.register_object(model)
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
-    async def register_object(self, obj: ShieldDef, p: Safety) -> None:
-        await p.register_shield(obj)
-
     async def list_shields(self) -> List[ShieldDef]:
         return self.registry
 
@@ -96,13 +113,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)
 
     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object_common(shield)
+        await self.register_object(shield)
 
 
 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
-        await p.register_memory_bank(obj)
-
     async def list_memory_banks(self) -> List[MemoryBankDef]:
         return self.registry
 
@@ -110,4 +124,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)
 
     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object_common(bank)
+        await self.register_object(bank)
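The routing table now plays two roles: it dispatches each registration to the right provider method based on the provider's API, and it installs itself as that provider's store. A stripped-down sketch of the dispatch half (the `Api` enum and provider class here are toy stand-ins; the real helper reads `p.__provider_spec__.api`):

```python
import asyncio
from enum import Enum


class Api(Enum):
    """Trimmed-down stand-in for llama_stack's Api enum."""

    inference = "inference"
    safety = "safety"
    memory = "memory"


class FakeInferenceProvider:
    """Illustrative provider; a plain attribute replaces __provider_spec__."""

    api = Api.inference

    async def register_model(self, obj) -> None:
        print(f"registered model {obj!r}")


async def register_object_with_provider(obj, p) -> None:
    # same dispatch shape as the helper added above
    if p.api == Api.inference:
        await p.register_model(obj)
    elif p.api == Api.safety:
        await p.register_shield(obj)
    elif p.api == Api.memory:
        await p.register_memory_bank(obj)


asyncio.run(
    register_object_with_provider("Llama3.1-8B-Instruct", FakeInferenceProvider())
)
```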
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index eeffb938d..6d106ccf1 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -6,39 +6,40 @@
 
 from typing import AsyncGenerator
 
-from openai import OpenAI
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+
+from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 from .config import DatabricksImplConfig
 
+
 DATABRICKS_SUPPORTED_MODELS = {
     "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
     "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
 }
 
 
-class DatabricksInferenceAdapter(Inference):
+class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+        )
         self.config = config
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> OpenAI:
-        return OpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token
-        )
+        return OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def initialize(self) -> None:
         return
@@ -65,18 +66,6 @@ class DatabricksInferenceAdapter(Inference):
 
         return databricks_messages
 
-    def resolve_databricks_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in DATABRICKS_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
-
-        return DATABRICKS_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -110,10 +99,9 @@ class DatabricksInferenceAdapter(Inference):
         messages = augment_messages_for_tools(request)
         options = self.get_databricks_chat_options(request)
-        databricks_model = self.resolve_databricks_model(request.model)
+        databricks_model = self.map_to_provider_model(request.model)
 
         if not request.stream:
-
             r = self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
@@ -154,10 +142,7 @@ class DatabricksInferenceAdapter(Inference):
             **options,
         ):
             if chunk.choices[0].finish_reason:
-                if (
-                    stop_reason is None
-                    and chunk.choices[0].finish_reason == "stop"
-                ):
+                if stop_reason is None and chunk.choices[0].finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
                 elif (
                     stop_reason is None
@@ -254,4 +239,4 @@ class DatabricksInferenceAdapter(Inference):
                 delta="",
                 stop_reason=stop_reason,
             )
-        )
\ No newline at end of file
+        )
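With `resolve_databricks_model` deleted, the adapter leans on the shared `ModelRegistryHelper` mixin. A sketch of the mapping contract under a toy table (the real helper first normalizes the identifier via `llama_models.sku_list.resolve_model`, which this stand-in skips):

```python
from typing import Dict


class ToyRegistryHelper:
    """Same contract as ModelRegistryHelper.map_to_provider_model, minus SKU resolution."""

    def __init__(self, stack_to_provider_models_map: Dict[str, str]) -> None:
        self.stack_to_provider_models_map = stack_to_provider_models_map

    def map_to_provider_model(self, identifier: str) -> str:
        if identifier not in self.stack_to_provider_models_map:
            raise ValueError(
                f"This adapter doesn't support model {identifier}. "
                f"Supported: {list(self.stack_to_provider_models_map)}"
            )
        return self.stack_to_provider_models_map[identifier]


helper = ToyRegistryHelper(
    {"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct"}
)
assert (
    helper.map_to_provider_model("Llama3.1-70B-Instruct")
    == "databricks-meta-llama-3-1-70b-instruct"
)
```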
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index aa9a25658..09af46b11 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -21,7 +21,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 
-OLLAMA_SUPPORTED_SKUS = {
+OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -33,7 +33,7 @@ OLLAMA_SUPPORTED_MODELS = {
 class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
         ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
+            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
         )
         self.url = url
         tokenizer = Tokenizer.get_instance()
diff --git a/llama_stack/providers/adapters/inference/sample/sample.py b/llama_stack/providers/adapters/inference/sample/sample.py
index 7d4e4a837..09171e395 100644
--- a/llama_stack/providers/adapters/inference/sample/sample.py
+++ b/llama_stack/providers/adapters/inference/sample/sample.py
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
-
 
-class SampleInferenceImpl(Inference, RoutableProvider):
+class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_model(self, model: ModelDef) -> None:
         # these are the model names the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 9868a9364..538c11ec7 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -12,6 +12,7 @@ from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -32,16 +33,18 @@ class _HfAdapter(Inference):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)
 
-    # TODO: make this work properly by checking this against the model_id being
-    # served by the remote endpoint
     async def register_model(self, model: ModelDef) -> None:
-        pass
+        resolved_model = resolve_model(model.identifier)
+        if resolved_model is None:
+            raise ValueError(f"Unknown model: {model.identifier}")
 
-    async def list_models(self) -> List[ModelDef]:
-        return []
+        if not resolved_model.huggingface_repo:
+            raise ValueError(
+                f"Model {model.identifier} does not have a HuggingFace repo"
+            )
 
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        return None
+        if self.model_id != resolved_model.huggingface_repo:
+            raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
 
     async def shutdown(self) -> None:
         pass
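The TGI adapter's `register_model` now performs three fail-fast checks instead of silently accepting anything. A self-contained sketch of the same checks; `resolve_model` and `ResolvedModel` below are stubs standing in for `llama_models.sku_list`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ResolvedModel:
    descriptor: str
    huggingface_repo: Optional[str]


_KNOWN = {
    "Llama3.1-8B-Instruct": ResolvedModel(
        "Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"
    )
}


def resolve_model(identifier: str) -> Optional[ResolvedModel]:
    return _KNOWN.get(identifier)


def check_model(identifier: str, served_model_id: str) -> None:
    resolved = resolve_model(identifier)
    if resolved is None:
        raise ValueError(f"Unknown model: {identifier}")
    if not resolved.huggingface_repo:
        raise ValueError(f"Model {identifier} does not have a HuggingFace repo")
    if served_model_id != resolved.huggingface_repo:
        raise ValueError(f"Model mismatch: {identifier} != {served_model_id}")


check_model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")  # passes
```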
diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py
index afa13111f..f720159a5 100644
--- a/llama_stack/providers/adapters/memory/chroma/chroma.py
+++ b/llama_stack/providers/adapters/memory/chroma/chroma.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import json
-import uuid
 from typing import List
 from urllib.parse import urlparse
 
@@ -13,7 +12,6 @@ import chromadb
 from numpy.typing import NDArray
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class ChromaMemoryAdapter(Memory, RoutableProvider):
+class ChromaMemoryAdapter(Memory):
     def __init__(self, url: str) -> None:
         print(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         collection = await self.client.create_collection(
-            name=bank_id,
-            metadata={"bank": bank.json()},
+            name=memory_bank.identifier,
         )
         bank_index = BankWithIndex(
-            bank=bank, index=ChromaIndex(self.client, collection)
+            bank=memory_bank, index=ChromaIndex(self.client, collection)
         )
-        self.cache[bank_id] = bank_index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank
+        self.cache[memory_bank.identifier] = bank_index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if bank is None:
+            raise ValueError(f"Bank {bank_id} not found")
+
         collections = await self.client.list_collections()
         for collection in collections:
             if collection.name == bank_id:
-                print(collection.metadata)
-                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                 index = BankWithIndex(
                     bank=bank,
                     index=ChromaIndex(self.client, collection),
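Both memory adapters now treat the injected `memory_bank_store` as the source of truth and keep only a local index cache; the pgvector adapter below follows the same shape. A condensed sketch of the lookup-then-cache pattern, with a plain dict standing in for the store:

```python
import asyncio
from typing import Dict, Optional


class BankIndex:
    """Stand-in for BankWithIndex: pairs a bank definition with its vector index."""

    def __init__(self, bank: str) -> None:
        self.bank = bank


class Adapter:
    def __init__(self, store: Dict[str, str]) -> None:
        self._store = store  # plays the role of the injected memory_bank_store
        self.cache: Dict[str, BankIndex] = {}

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankIndex]:
        if bank_id in self.cache:  # fast path: index already materialized
            return self.cache[bank_id]
        bank = self._store.get(bank_id)  # authoritative lookup via the routing table
        if bank is None:
            raise ValueError(f"Bank {bank_id} not found")
        index = BankIndex(bank)
        self.cache[bank_id] = index
        return index


async def main() -> None:
    adapter = Adapter({"docs": "docs-bank-def"})
    assert (await adapter._get_and_cache_bank_index("docs")).bank == "docs-bank-def"


asyncio.run(main())
```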
diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py
index 5864aa7dc..c5dc1f4be 100644
--- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py
+++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py
@@ -4,18 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
-from typing import List, Tuple
+from typing import List
 
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
 
-from pydantic import BaseModel
-
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
@@ -32,33 +28,6 @@ def check_extension_version(cur):
     return result[0] if result else None
 
 
-def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
-    query = sql.SQL(
-        """
-        INSERT INTO metadata_store (key, data)
-        VALUES %s
-        ON CONFLICT (key) DO UPDATE
-        SET data = EXCLUDED.data
-    """
-    )
-
-    values = [(key, Json(model.dict())) for key, model in keys_models]
-    execute_values(cur, query, values, template="(%s, %s)")
-
-
-def load_models(cur, keys: List[str], cls):
-    query = "SELECT key, data FROM metadata_store"
-    if keys:
-        placeholders = ",".join(["%s"] * len(keys))
-        query += f" WHERE key IN ({placeholders})"
-        cur.execute(query, keys)
-    else:
-        cur.execute(query)
-
-    rows = cur.fetchall()
-    return [cls(**row["data"]) for row in rows]
-
-
 class PGVectorIndex(EmbeddingIndex):
     def __init__(self, bank: MemoryBank, dimension: int, cursor):
         self.cursor = cursor
@@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
 
-class PGVectorMemoryAdapter(Memory, RoutableProvider):
+class PGVectorMemoryAdapter(Memory):
     def __init__(self, config: PGVectorConfig) -> None:
         print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
         self.config = config
@@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory):
             else:
                 raise RuntimeError("Vector extension is not installed.")
 
-            self.cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS metadata_store (
-                    key TEXT PRIMARY KEY,
-                    data JSONB
-                )
-            """
-            )
         except Exception as e:
             import traceback
@@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
-        upsert_models(
-            self.cursor,
-            [
-                (bank.bank_id, bank),
-            ],
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+            bank=memory_bank,
+            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
         )
-        self.cache[bank_id] = index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank
+        self.cache[memory_bank.identifier] = index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
-        banks = load_models(self.cursor, [bank_id], MemoryBank)
-        if not banks:
-            return None
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if not bank:
+            raise ValueError(f"Bank {bank_id} not found")
 
-        bank = banks[0]
         index = BankWithIndex(
             bank=bank,
             index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
diff --git a/llama_stack/providers/adapters/memory/sample/sample.py b/llama_stack/providers/adapters/memory/sample/sample.py
index 7ef4a625d..3431b87d5 100644
--- a/llama_stack/providers/adapters/memory/sample/sample.py
+++ b/llama_stack/providers/adapters/memory/sample/sample.py
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.memory import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
-
 
-class SampleMemoryImpl(Memory, RoutableProvider):
+class SampleMemoryImpl(Memory):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         # these are the memory banks the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
        pass
diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py
index 814704e2c..7fbac2e4b 100644
--- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py
@@ -7,14 +7,12 @@
 
 import json
 import logging
-import traceback
 from typing import Any, Dict, List
 
 import boto3
 
 from llama_stack.apis.safety import *  # noqa
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider
 
 from .config import BedrockSafetyConfig
 
@@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
 logger = logging.getLogger(__name__)
 
 
-SUPPORTED_SHIELD_TYPES = [
-    "bedrock_guardrail",
+BEDROCK_SUPPORTED_SHIELDS = [
+    ShieldType.generic_content_shield.value,
 ]
 
 
-class BedrockSafetyAdapter(Safety, RoutableProvider):
+class BedrockSafetyAdapter(Safety):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         if not config.aws_profile:
             raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
         self.config = config
+        self.registered_shields = []
 
     async def initialize(self) -> None:
         try:
@@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        for key in routing_keys:
-            if key not in SUPPORTED_SHIELD_TYPES:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+        shield_params = shield.params
+        if "guardrailIdentifier" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGuardrails: missing guardrailIdentifier in request"
+            )
+
+        if "guardrailVersion" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGuardrails: missing guardrailVersion in request"
+            )
 
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type not in SUPPORTED_SHIELD_TYPES:
-            raise ValueError(f"Unknown safety shield type: {shield_type}")
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")
 
         """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
         ```content = [
@@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
             {
                 "text": {
                     "text": "Is the AB503 Product a better investment than the S&P 500?"
                 }
            }
         ]```
         They contain content, role . For now we will extract the content and
         default the "qualifiers": ["query"]
         """
-        try:
-            logger.debug(f"run_shield::{params}::messages={messages}")
-            if "guardrailIdentifier" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
-                )
-            if "guardrailVersion" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
-                )
+        shield_params = shield_def.params
+        logger.debug(f"run_shield::{shield_params}::messages={messages}")
 
-            # - convert the messages into format Bedrock expects
-            content_messages = []
-            for message in messages:
-                content_messages.append({"text": {"text": message.content}})
-            logger.debug(
-                f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
-            )
+        # - convert the messages into format Bedrock expects
+        content_messages = []
+        for message in messages:
+            content_messages.append({"text": {"text": message.content}})
+        logger.debug(
+            f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
+        )
 
-            response = self.boto_client.apply_guardrail(
-                guardrailIdentifier=params.get("guardrailIdentifier"),
-                guardrailVersion=params.get("guardrailVersion"),
-                source="OUTPUT",  # or 'INPUT' depending on your use case
-                content=content_messages,
-            )
-            logger.debug(f"run_shield:: response: {response}::")
-            if response["action"] == "GUARDRAIL_INTERVENED":
-                user_message = ""
-                metadata = {}
-                for output in response["outputs"]:
-                    # guardrails returns a list - however for this implementation we will leverage the last values
-                    user_message = output["text"]
-                for assessment in response["assessments"]:
-                    # guardrails returns a list - however for this implementation we will leverage the last values
-                    metadata = dict(assessment)
-                return SafetyViolation(
-                    user_message=user_message,
-                    violation_level=ViolationLevel.ERROR,
-                    metadata=metadata,
-                )
+        response = self.boto_client.apply_guardrail(
+            guardrailIdentifier=shield_params["guardrailIdentifier"],
+            guardrailVersion=shield_params["guardrailVersion"],
+            source="OUTPUT",  # or 'INPUT' depending on your use case
+            content=content_messages,
+        )
+        if response["action"] == "GUARDRAIL_INTERVENED":
+            user_message = ""
+            metadata = {}
+            for output in response["outputs"]:
+                # guardrails returns a list - however for this implementation we will leverage the last values
+                user_message = output["text"]
+            for assessment in response["assessments"]:
+                # guardrails returns a list - however for this implementation we will leverage the last values
+                metadata = dict(assessment)
 
-        except Exception:
-            error_str = traceback.format_exc()
-            logger.error(
-                f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
+            return SafetyViolation(
+                user_message=user_message,
+                violation_level=ViolationLevel.ERROR,
+                metadata=metadata,
             )
 
         return None
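With validation moved into `register_shield`, the guardrail coordinates travel on the shield definition instead of on each `run_shield` call. A sketch of the shape the adapter now expects; the dataclass mirrors only the fields read above, and the identifier and parameter values are made up:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ShieldDef:
    """Minimal stand-in mirroring the fields this adapter reads."""

    identifier: str
    type: str
    params: Dict[str, Any] = field(default_factory=dict)


shield = ShieldDef(
    identifier="my-bedrock-guardrail",  # hypothetical name
    type="generic_content_shield",  # i.e. ShieldType.generic_content_shield.value
    params={
        "guardrailIdentifier": "abc123",  # placeholder AWS guardrail id
        "guardrailVersion": "1",
    },
)

# the same checks register_shield performs above
assert "guardrailIdentifier" in shield.params
assert "guardrailVersion" in shield.params
```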
diff --git a/llama_stack/providers/adapters/safety/sample/sample.py b/llama_stack/providers/adapters/safety/sample/sample.py
index a71f5143f..1aecf1ad0 100644
--- a/llama_stack/providers/adapters/safety/sample/sample.py
+++ b/llama_stack/providers/adapters/safety/sample/sample.py
@@ -9,14 +9,12 @@ from .config import SampleConfig
 
 from llama_stack.apis.safety import *  # noqa: F403
 
-from llama_stack.distribution.datatypes import RoutableProvider
-
 
-class SampleSafetyImpl(Safety, RoutableProvider):
+class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_shield(self, shield: ShieldDef) -> None:
         # these are the safety shields the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py
index 9d9fa6a4e..fa6ec395d 100644
--- a/llama_stack/providers/adapters/safety/together/together.py
+++ b/llama_stack/providers/adapters/safety/together/together.py
@@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 
 from .config import TogetherSafetyConfig
 
-SAFETY_SHIELD_MODEL_MAP = {
+TOGETHER_SHIELD_MODEL_MAP = {
     "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@@ -22,7 +22,6 @@ TOGETHER_SHIELD_MODEL_MAP = {
 class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config
-        self.register_shields = []
 
     async def initialize(self) -> None:
         pass
@@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
         if shield.type != ShieldType.llama_guard.value:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")
 
         model = shield_def.params.get("model", "llama_guard")
-        if model not in SAFETY_SHIELD_MODEL_MAP:
+        if model not in TOGETHER_SHIELD_MODEL_MAP:
             raise ValueError(f"Unsupported safety model: {model}")
 
         together_api_key = None
@@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, model, api_messages)
+        violation = await get_safety_response(
+            together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
+        )
 
         return RunShieldResponse(violation=violation)
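Note the last hunk also fixes a quiet bug: the shield's model alias (e.g. `llama_guard`) used to be passed to `get_safety_response` verbatim; it is now translated through the map first. In miniature:

```python
TOGETHER_SHIELD_MODEL_MAP = {
    "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
}

model = "llama_guard"  # alias stored in the shield's params
assert TOGETHER_SHIELD_MODEL_MAP[model] == "meta-llama/Meta-Llama-Guard-3-8B"
```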
diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py
index 1534971cd..7c59f5d59 100644
--- a/llama_stack/providers/impls/meta_reference/memory/faiss.py
+++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py
@@ -83,15 +83,6 @@ class FaissMemoryImpl(Memory):
         )
         self.cache[memory_bank.identifier] = index
 
-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
-        index = self.cache.get(identifier)
-        if index is None:
-            return None
-        return index.bank
-
-    async def list_memory_banks(self) -> List[MemoryBankDef]:
-        return [x.bank for x in self.cache.values()]
-
     async def insert_documents(
         self,
         bank_id: str,
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 5154acd77..5d6747f9f 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -33,7 +33,6 @@ class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
-        self.registered_shields = []
 
         self.available_shields = [ShieldType.code_scanner.value]
         if config.llama_guard_shield:
@@ -55,24 +54,13 @@ class MetaReferenceSafetyImpl(Safety):
         if shield.type not in self.available_shields:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")
 
-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self,
         shield_type: str,
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")
 
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index dabf698d4..744a89084 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List
+from typing import Dict
 
 from llama_models.sku_list import resolve_model
 
@@ -15,7 +15,6 @@ class ModelRegistryHelper:
 
     def __init__(self, stack_to_provider_models_map: Dict[str, str]):
         self.stack_to_provider_models_map = stack_to_provider_models_map
-        self.registered_models = []
 
     def map_to_provider_model(self, identifier: str) -> str:
         model = resolve_model(identifier)
@@ -30,22 +29,7 @@ class ModelRegistryHelper:
         return self.stack_to_provider_models_map[identifier]
 
     async def register_model(self, model: ModelDef) -> None:
-        existing = await self.get_model(model.identifier)
-        if existing is not None:
-            return
-
         if model.identifier not in self.stack_to_provider_models_map:
             raise ValueError(
                 f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
-
-        self.registered_models.append(model)
-
-    async def list_models(self) -> List[ModelDef]:
-        return self.registered_models
-
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        for model in self.registered_models:
-            if model.identifier == identifier:
-                return model
-        return None