Introduce model_store, shield_store, memory_bank_store

Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
     embeddings: List[List[float]]


+class ModelStore(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
 class Inference(Protocol):
+    model_store: ModelStore
+
     @webmethod(route="/inference/completion")
     async def completion(
         self,
@@ -207,9 +213,3 @@ class Inference(Protocol):

     @webmethod(route="/inference/register_model")
     async def register_model(self, model: ModelDef) -> None: ...
-
-    @webmethod(route="/inference/list_models")
-    async def list_models(self) -> List[ModelDef]: ...
-
-    @webmethod(route="/inference/get_model")
-    async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
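
Note: the list/get endpoints move off the providers entirely. Each provider now receives a `model_store` (assigned by the routing table; see routing_tables.py below) and resolves model metadata through it. A minimal sketch of the pattern from a provider's side — the class and method bodies here are hypothetical, only `ModelStore`, `model_store`, and `register_model` come from the API above:

    class ToyInferenceProvider(Inference):
        # injected by CommonRoutingTableImpl at startup
        model_store: ModelStore

        async def register_model(self, model: ModelDef) -> None:
            # validate only; the routing table owns the authoritative registry
            if model.identifier not in {"Llama3.1-8B-Instruct"}:
                raise ValueError(f"Unsupported model: {model.identifier}")

        async def completion(self, model: str, **kwargs):
            model_def = self.model_store.get_model(model)  # central lookup
            ...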

View file

@@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
     scores: List[float]


+class MemoryBankStore(Protocol):
+    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
+
+
 class Memory(Protocol):
+    memory_bank_store: MemoryBankStore
+
     # this will just block now until documents are inserted, but it should
     # probably return a Job instance which can be polled for completion
     @webmethod(route="/memory/insert")
@@ -80,9 +86,3 @@ class Memory(Protocol):

     @webmethod(route="/memory/register_memory_bank")
     async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
-
-    @webmethod(route="/memory/list_memory_banks")
-    async def list_memory_banks(self) -> List[MemoryBankDef]: ...
-
-    @webmethod(route="/memory/get_memory_bank")
-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...

View file

@@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None


+class ShieldStore(Protocol):
+    def get_shield(self, identifier: str) -> ShieldDef: ...
+
+
 class Safety(Protocol):
+    shield_store: ShieldStore
+
     @webmethod(route="/safety/run_shield")
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
@@ -46,9 +52,3 @@ class Safety(Protocol):

     @webmethod(route="/safety/register_shield")
     async def register_shield(self, shield: ShieldDef) -> None: ...
-
-    @webmethod(route="/safety/list_shields")
-    async def list_shields(self) -> List[ShieldDef]: ...
-
-    @webmethod(route="/safety/get_shield")
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...
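
The same inversion applies to shields and memory banks: a provider's `run_shield` no longer consults a provider-local registry but resolves the definition through the injected `shield_store`, roughly like this (a sketch; the concrete versions appear in the Bedrock, Together, and meta-reference diffs below):

    async def run_shield(self, shield_type, messages, params=None):
        shield_def = await self.shield_store.get_shield(shield_type)
        if not shield_def:
            raise ValueError(f"Unknown shield {shield_type}")
        # provider-specific enforcement driven by shield_def.params
        ...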

View file

@@ -39,6 +39,14 @@ RoutedProtocol = Union[
 ]


+class ModelRegistry(Protocol):
+    def get_model(self, identifier: str) -> ModelDef: ...
+
+
+class MemoryBankRegistry(Protocol):
+    def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
+
+
 # Example: /inference, /safety
 class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"

View file

@@ -151,6 +151,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
             deps,
             inner_impls,
         )
+        # TODO: ugh slightly redesign this shady looking code
         if "inner-" in api_str:
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl
         else:

View file

@@ -15,6 +15,20 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403


+def get_impl_api(p: Any) -> Api:
+    return p.__provider_spec__.api
+
+
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+    api = get_impl_api(p)
+    if api == Api.inference:
+        await p.register_model(obj)
+    elif api == Api.safety:
+        await p.register_shield(obj)
+    elif api == Api.memory:
+        await p.register_memory_bank(obj)
+
+
 # TODO: this routing table maintains state in memory purely. We need to
 # add persistence to it when we add dynamic registration of objects.
 class CommonRoutingTableImpl(RoutingTable):
@@ -32,6 +46,15 @@ class CommonRoutingTableImpl(RoutingTable):
         self.impls_by_provider_id = impls_by_provider_id
         self.registry = registry

+        for p in self.impls_by_provider_id.values():
+            api = get_impl_api(p)
+            if api == Api.inference:
+                p.model_store = self
+            elif api == Api.safety:
+                p.shield_store = self
+            elif api == Api.memory:
+                p.memory_bank_store = self
+
         self.routing_key_to_object = {}
         for obj in self.registry:
             self.routing_key_to_object[obj.identifier] = obj
@@ -39,7 +62,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         for obj in self.registry:
             p = self.impls_by_provider_id[obj.provider_id]
-            await self.register_object(obj, p)
+            await register_object_with_provider(obj, p)

     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -57,7 +80,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 return obj
         return None

-    async def register_object_common(self, obj: RoutableObject) -> None:
+    async def register_object(self, obj: RoutableObject) -> Any:
         if obj.identifier in self.routing_key_to_object:
             raise ValueError(f"Object `{obj.identifier}` already registered")

@@ -65,16 +88,13 @@ class CommonRoutingTableImpl(RoutingTable):
             raise ValueError(f"Provider `{obj.provider_id}` not found")

         p = self.impls_by_provider_id[obj.provider_id]
-        await p.register_object(obj)
+        await register_object_with_provider(obj, p)
         self.routing_key_to_object[obj.identifier] = obj
         self.registry.append(obj)


 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
-    async def register_object(self, obj: ModelDef, p: Inference) -> None:
-        await p.register_model(obj)
-
     async def list_models(self) -> List[ModelDef]:
         return self.registry

@@ -82,13 +102,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         return self.get_object_by_identifier(identifier)

     async def register_model(self, model: ModelDef) -> None:
-        await self.register_object_common(model)
+        await self.register_object(model)


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
-    async def register_object(self, obj: ShieldDef, p: Safety) -> None:
-        await p.register_shield(obj)
-
     async def list_shields(self) -> List[ShieldDef]:
         return self.registry

@@ -96,13 +113,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         return self.get_object_by_identifier(shield_type)

     async def register_shield(self, shield: ShieldDef) -> None:
-        await self.register_object_common(shield)
+        await self.register_object(shield)


 class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
-    async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
-        await p.register_memory_bank(obj)
-
     async def list_memory_banks(self) -> List[MemoryBankDef]:
         return self.registry

@@ -110,4 +124,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
         return self.get_object_by_identifier(identifier)

     async def register_memory_bank(self, bank: MemoryBankDef) -> None:
-        await self.register_object_common(bank)
+        await self.register_object(bank)
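
Net effect: the three per-API `register_object` overrides collapse into one module-level dispatch keyed on the provider's API, and the routing table injects itself as each provider's store. A sketch of the runtime flow for dynamic registration (constructor arguments and `ModelDef` fields are abbreviated/assumed):

    table = ModelsRoutingTable(impls_by_provider_id, registry)
    await table.initialize()  # pushes every pre-declared object to its provider

    # dynamic registration funnels through the shared path:
    await table.register_model(ModelDef(identifier="my-model", provider_id="ollama"))
    # -> CommonRoutingTableImpl.register_object
    #    -> register_object_with_provider(obj, p) -> p.register_model(obj)
    #    -> routing_key_to_object and registry updated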

View file

@@ -6,39 +6,40 @@
 from typing import AsyncGenerator

+from openai import OpenAI
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
-
-from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from .config import DatabricksImplConfig

 DATABRICKS_SUPPORTED_MODELS = {
     "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
     "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
 }


-class DatabricksInferenceAdapter(Inference):
+class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+        )
         self.config = config
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> OpenAI:
-        return OpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token
-        )
+        return OpenAI(base_url=self.config.url, api_key=self.config.api_token)

     async def initialize(self) -> None:
         return
@@ -65,18 +66,6 @@ class DatabricksInferenceAdapter(Inference):
         return databricks_messages

-    def resolve_databricks_model(self, model_name: str) -> str:
-        model = resolve_model(model_name)
-        assert (
-            model is not None
-            and model.descriptor(shorten_default_variant=True)
-            in DATABRICKS_SUPPORTED_MODELS
-        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
-
-        return DATABRICKS_SUPPORTED_MODELS.get(
-            model.descriptor(shorten_default_variant=True)
-        )
-
     def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -110,10 +99,9 @@ class DatabricksInferenceAdapter(Inference):
         messages = augment_messages_for_tools(request)
         options = self.get_databricks_chat_options(request)
-        databricks_model = self.resolve_databricks_model(request.model)
+        databricks_model = self.map_to_provider_model(request.model)

         if not request.stream:
             r = self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
@@ -154,10 +142,7 @@ class DatabricksInferenceAdapter(Inference):
                 **options,
             ):
                 if chunk.choices[0].finish_reason:
-                    if (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "stop"
-                    ):
+                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
                         stop_reason = StopReason.end_of_turn
                     elif (
                         stop_reason is None
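
With `ModelRegistryHelper` mixed in, the hand-rolled `resolve_databricks_model` disappears and identifier mapping goes through the shared `map_to_provider_model`. Illustrative only:

    adapter = DatabricksInferenceAdapter(config)  # config: DatabricksImplConfig
    adapter.map_to_provider_model("Llama3.1-70B-Instruct")
    # -> "databricks-meta-llama-3-1-70b-instruct"
    adapter.map_to_provider_model("not-a-model")
    # -> raises ValueError (identifier not in the map; see model_registry.py below)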

View file

@@ -21,7 +21,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

-OLLAMA_SUPPORTED_SKUS = {
+OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
     "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -33,7 +33,7 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, url: str) -> None:
         ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
+            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
         )
         self.url = url
         tokenizer = Tokenizer.get_instance()

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig

 from llama_stack.apis.inference import *  # noqa: F403

-from llama_stack.distribution.datatypes import RoutableProvider
-

-class SampleInferenceImpl(Inference, RoutableProvider):
+class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config

-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_model(self, model: ModelDef) -> None:
         # these are the model names the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass

View file

@@ -12,6 +12,7 @@ from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
@@ -32,16 +33,18 @@ class _HfAdapter(Inference):
         self.tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(self.tokenizer)

+    # TODO: make this work properly by checking this against the model_id being
+    # served by the remote endpoint
     async def register_model(self, model: ModelDef) -> None:
-        pass
-
-    async def list_models(self) -> List[ModelDef]:
-        return []
-
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        return None
+        resolved_model = resolve_model(model.identifier)
+        if resolved_model is None:
+            raise ValueError(f"Unknown model: {model.identifier}")
+
+        if not resolved_model.huggingface_repo:
+            raise ValueError(
+                f"Model {model.identifier} does not have a HuggingFace repo"
+            )
+
+        if self.model_id != resolved_model.huggingface_repo:
+            raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")

     async def shutdown(self) -> None:
         pass

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import json
-import uuid
 from typing import List
 from urllib.parse import urlparse

@@ -13,7 +12,6 @@ import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider

 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class ChromaMemoryAdapter(Memory, RoutableProvider):
+class ChromaMemoryAdapter(Memory):
     def __init__(self, url: str) -> None:
         print(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
@@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
     async def shutdown(self) -> None:
         pass

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         collection = await self.client.create_collection(
-            name=bank_id,
-            metadata={"bank": bank.json()},
+            name=memory_bank.identifier,
         )
         bank_index = BankWithIndex(
-            bank=bank, index=ChromaIndex(self.client, collection)
+            bank=memory_bank, index=ChromaIndex(self.client, collection)
         )
-        self.cache[bank_id] = bank_index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank
+        self.cache[memory_bank.identifier] = bank_index

     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]

+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if bank is None:
+            raise ValueError(f"Bank {bank_id} not found")
+
         collections = await self.client.list_collections()
         for collection in collections:
             if collection.name == bank_id:
-                print(collection.metadata)
-                bank = MemoryBank(**json.loads(collection.metadata["bank"]))
                 index = BankWithIndex(
                     bank=bank,
                     index=ChromaIndex(self.client, collection),
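
Chroma no longer persists bank metadata inside collection metadata; the definition of record lives in the central registry, and the adapter only maps identifiers to collections. A rough usage sketch (`MemoryBankDef` fields are assumed from context):

    bank = MemoryBankDef(
        identifier="docs",
        type=MemoryBankType.vector.value,
        provider_id="chroma",
    )
    await adapter.register_memory_bank(bank)  # creates Chroma collection "docs"
    # later, even after a restart, lookups rehydrate from memory_bank_store:
    index = await adapter._get_and_cache_bank_index("docs")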

View file

@@ -4,18 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import uuid
-from typing import List, Tuple
+from typing import List

 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
-from pydantic import BaseModel

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider

 from llama_stack.providers.utils.memory.vector_store import (
     ALL_MINILM_L6_V2_DIMENSION,
@@ -32,33 +28,6 @@ def check_extension_version(cur):
     return result[0] if result else None


-def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
-    query = sql.SQL(
-        """
-        INSERT INTO metadata_store (key, data)
-        VALUES %s
-        ON CONFLICT (key) DO UPDATE
-        SET data = EXCLUDED.data
-        """
-    )
-    values = [(key, Json(model.dict())) for key, model in keys_models]
-    execute_values(cur, query, values, template="(%s, %s)")
-
-
-def load_models(cur, keys: List[str], cls):
-    query = "SELECT key, data FROM metadata_store"
-    if keys:
-        placeholders = ",".join(["%s"] * len(keys))
-        query += f" WHERE key IN ({placeholders})"
-        cur.execute(query, keys)
-    else:
-        cur.execute(query)
-
-    rows = cur.fetchall()
-    return [cls(**row["data"]) for row in rows]
-
-
 class PGVectorIndex(EmbeddingIndex):
     def __init__(self, bank: MemoryBank, dimension: int, cursor):
         self.cursor = cursor
@@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)


-class PGVectorMemoryAdapter(Memory, RoutableProvider):
+class PGVectorMemoryAdapter(Memory):
     def __init__(self, config: PGVectorConfig) -> None:
         print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
         self.config = config
@@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
             else:
                 raise RuntimeError("Vector extension is not installed.")

-            self.cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS metadata_store (
-                    key TEXT PRIMARY KEY,
-                    data JSONB
-                )
-                """
-            )
         except Exception as e:
             import traceback

@@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
     async def shutdown(self) -> None:
         pass

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
-        pass
-
-    async def create_memory_bank(
+    async def register_memory_bank(
         self,
-        name: str,
-        config: MemoryBankConfig,
-        url: Optional[URL] = None,
-    ) -> MemoryBank:
-        bank_id = str(uuid.uuid4())
-        bank = MemoryBank(
-            bank_id=bank_id,
-            name=name,
-            config=config,
-            url=url,
-        )
-        upsert_models(
-            self.cursor,
-            [
-                (bank.bank_id, bank),
-            ],
-        )
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported {memory_bank.type}"
+
         index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
+            bank=memory_bank,
+            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
         )
         self.cache[bank_id] = index
-        return bank
-
-    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        bank_index = await self._get_and_cache_bank_index(bank_id)
-        if bank_index is None:
-            return None
-        return bank_index.bank

     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
             return self.cache[bank_id]

-        banks = load_models(self.cursor, [bank_id], MemoryBank)
-        if not banks:
-            return None
-        bank = banks[0]
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if not bank:
+            raise ValueError(f"Bank {bank_id} not found")

         index = BankWithIndex(
             bank=bank,
             index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig

 from llama_stack.apis.memory import *  # noqa: F403

-from llama_stack.distribution.datatypes import RoutableProvider
-

-class SampleMemoryImpl(Memory, RoutableProvider):
+class SampleMemoryImpl(Memory):
     def __init__(self, config: SampleConfig):
         self.config = config

-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
         # these are the memory banks the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass

View file

@@ -7,14 +7,12 @@

 import json
 import logging
-import traceback

 from typing import Any, Dict, List

 import boto3

 from llama_stack.apis.safety import *  # noqa
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.distribution.datatypes import RoutableProvider

 from .config import BedrockSafetyConfig

@@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
 logger = logging.getLogger(__name__)


-SUPPORTED_SHIELD_TYPES = [
-    "bedrock_guardrail",
+BEDROCK_SUPPORTED_SHIELDS = [
+    ShieldType.generic_content_shield.value,
 ]


-class BedrockSafetyAdapter(Safety, RoutableProvider):
+class BedrockSafetyAdapter(Safety):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         if not config.aws_profile:
             raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
         self.config = config
+        self.registered_shields = []

     async def initialize(self) -> None:
         try:
@@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
     async def shutdown(self) -> None:
         pass

-    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
-        for key in routing_keys:
-            if key not in SUPPORTED_SHIELD_TYPES:
-                raise ValueError(f"Unknown safety shield type: {key}")
+    async def register_shield(self, shield: ShieldDef) -> None:
+        if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
+            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+
+        shield_params = shield.params
+        if "guardrailIdentifier" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
+            )
+
+        if "guardrailVersion" not in shield_params:
+            raise ValueError(
+                "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
+            )

     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type not in SUPPORTED_SHIELD_TYPES:
-            raise ValueError(f"Unknown safety shield type: {shield_type}")
+        shield_def = await self.shield_store.get_shield(shield_type)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {shield_type}")

         """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
         ```content = [
@@ -69,17 +79,9 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
         They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
         """

-        try:
-            logger.debug(f"run_shield::{params}::messages={messages}")
-            if "guardrailIdentifier" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing GuardrailID in request"
-                )
-
-            if "guardrailVersion" not in params:
-                raise RuntimeError(
-                    "Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
-                )
+        shield_params = shield_def.params
+        logger.debug(f"run_shield::{shield_params}::messages={messages}")

         # - convert the messages into format Bedrock expects
         content_messages = []
@@ -90,12 +92,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
             )

         response = self.boto_client.apply_guardrail(
-            guardrailIdentifier=params.get("guardrailIdentifier"),
-            guardrailVersion=params.get("guardrailVersion"),
+            guardrailIdentifier=shield_params["guardrailIdentifier"],
+            guardrailVersion=shield_params["guardrailVersion"],
             source="OUTPUT",  # or 'INPUT' depending on your use case
             content=content_messages,
         )
-        logger.debug(f"run_shield:: response: {response}::")
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
@@ -105,16 +106,11 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
             for assessment in response["assessments"]:
                 # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)
+
             return SafetyViolation(
                 user_message=user_message,
                 violation_level=ViolationLevel.ERROR,
                 metadata=metadata,
             )
-        except Exception:
-            error_str = traceback.format_exc()
-            logger.error(
-                f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
-            )

         return None
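
Guardrail configuration moves from per-request `params` into the shield definition itself: `register_shield` validates `guardrailIdentifier`/`guardrailVersion` once at registration time, and `run_shield` reads them back from the resolved `ShieldDef`. A definition would look roughly like this (values are placeholders; `ShieldDef` fields beyond `identifier`, `type`, and `params` are not shown in this diff):

    shield = ShieldDef(
        identifier="bedrock-guardrail",          # routing key
        type=ShieldType.generic_content_shield.value,
        params={
            "guardrailIdentifier": "gr-abc123",  # your Bedrock guardrail ID
            "guardrailVersion": "1",
        },
    )
    await adapter.register_shield(shield)  # raises if either param is missing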

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig

 from llama_stack.apis.safety import *  # noqa: F403

-from llama_stack.distribution.datatypes import RoutableProvider
-

-class SampleSafetyImpl(Safety, RoutableProvider):
+class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config

-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+    async def register_shield(self, shield: ShieldDef) -> None:
         # these are the safety shields the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass

View file

@@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData

 from .config import TogetherSafetyConfig

-SAFETY_SHIELD_MODEL_MAP = {
+TOGETHER_SHIELD_MODEL_MAP = {
     "llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
     "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@@ -22,7 +22,6 @@ SAFETY_SHIELD_MODEL_MAP = {
 class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config
-        self.register_shields = []

     async def initialize(self) -> None:
         pass
@@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
         if shield.type != ShieldType.llama_guard.value:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")

-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")

         model = shield_def.params.get("model", "llama_guard")
-        if model not in SAFETY_SHIELD_MODEL_MAP:
+        if model not in TOGETHER_SHIELD_MODEL_MAP:
             raise ValueError(f"Unsupported safety model: {model}")

         together_api_key = None
@@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})

-        violation = await get_safety_response(together_api_key, model, api_messages)
+        violation = await get_safety_response(
+            together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
+        )

         return RunShieldResponse(violation=violation)

View file

@@ -83,15 +83,6 @@ class FaissMemoryImpl(Memory):
         )
         self.cache[memory_bank.identifier] = index

-    async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
-        index = self.cache.get(identifier)
-        if index is None:
-            return None
-        return index.bank
-
-    async def list_memory_banks(self) -> List[MemoryBankDef]:
-        return [x.bank for x in self.cache.values()]
-
     async def insert_documents(
         self,
         bank_id: str,

View file

@@ -33,7 +33,6 @@ class MetaReferenceSafetyImpl(Safety):
     def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
         self.inference_api = deps[Api.inference]
-        self.registered_shields = []

         self.available_shields = [ShieldType.code_scanner.value]
         if config.llama_guard_shield:
@@ -55,24 +54,13 @@ class MetaReferenceSafetyImpl(Safety):
         if shield.type not in self.available_shields:
             raise ValueError(f"Unsupported safety shield type: {shield.type}")

-        self.registered_shields.append(shield)
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return self.registered_shields
-
-    async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
-        for shield in self.registered_shields:
-            if shield.identifier == identifier:
-                return shield
-        return None
-
     async def run_shield(
         self,
         shield_type: str,
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        shield_def = await self.get_shield(shield_type)
+        shield_def = await self.shield_store.get_shield(shield_type)
         if not shield_def:
             raise ValueError(f"Unknown shield {shield_type}")

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict, List
+from typing import Dict

 from llama_models.sku_list import resolve_model

@@ -15,7 +15,6 @@ class ModelRegistryHelper:

     def __init__(self, stack_to_provider_models_map: Dict[str, str]):
         self.stack_to_provider_models_map = stack_to_provider_models_map
-        self.registered_models = []

     def map_to_provider_model(self, identifier: str) -> str:
         model = resolve_model(identifier)
@@ -30,22 +29,7 @@ class ModelRegistryHelper:
         return self.stack_to_provider_models_map[identifier]

     async def register_model(self, model: ModelDef) -> None:
-        existing = await self.get_model(model.identifier)
-        if existing is not None:
-            return
-
         if model.identifier not in self.stack_to_provider_models_map:
             raise ValueError(
                 f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
-
-        self.registered_models.append(model)
-
-    async def list_models(self) -> List[ModelDef]:
-        return self.registered_models
-
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        for model in self.registered_models:
-            if model.identifier == identifier:
-                return model
-        return None
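
After this change `ModelRegistryHelper` is stateless: `register_model` is pure validation against the static map, and list/get live in the routing table. Usage matches the Ollama and Databricks adapters above, e.g. (a sketch; `ModelDef` fields abbreviated):

    helper = ModelRegistryHelper(stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS)
    helper.map_to_provider_model("Llama3.1-8B-Instruct")  # -> "llama3.1:8b-instruct-fp16"
    await helper.register_model(ModelDef(identifier="Llama3.1-8B-Instruct", provider_id="ollama"))
    # unsupported identifiers raise ValueError; nothing is stored on the helper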