Introduce model_store, shield_store, memory_bank_store

Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@@ -173,7 +173,13 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
async def completion(
self,
@@ -207,9 +213,3 @@ class Inference(Protocol):
@webmethod(route="/inference/register_model")
async def register_model(self, model: ModelDef) -> None: ...
@webmethod(route="/inference/list_models")
async def list_models(self) -> List[ModelDef]: ...
@webmethod(route="/inference/get_model")
async def get_model(self, identifier: str) -> Optional[ModelDef]: ...
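The pattern this commit introduces: providers keep `register_model` for validation, while `list_models`/`get_model` move behind the routing table, which injects itself into each provider as a `model_store`. A minimal sketch of the provider side, assuming a hypothetical `MyInferenceAdapter` (only the `ModelStore` and `Inference` protocols above are from this commit; `get_model` is called synchronously because the protocol declares it that way):

from llama_stack.apis.inference import *  # noqa: F403


class MyInferenceAdapter(Inference):
    # assigned by the routing table during wiring (see CommonRoutingTableImpl below)
    model_store: ModelStore

    async def register_model(self, model: ModelDef) -> None:
        # validate that this provider can serve `model`; the routing table
        # owns the registry, so nothing needs to be stored here
        pass

    async def completion(self, model: str, content: str, **kwargs):
        # signature abbreviated; resolve the registered definition instead of
        # keeping a local list of models
        model_def = self.model_store.get_model(model)
        ...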

View file

@@ -38,7 +38,13 @@ class QueryDocumentsResponse(BaseModel):
scores: List[float]
class MemoryBankStore(Protocol):
def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ...
class Memory(Protocol):
memory_bank_store: MemoryBankStore
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/memory/insert")
@@ -80,9 +86,3 @@ class Memory(Protocol):
@webmethod(route="/memory/register_memory_bank")
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
@webmethod(route="/memory/list_memory_banks")
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
@webmethod(route="/memory/get_memory_bank")
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]: ...

View file

@@ -38,7 +38,13 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
@@ -46,9 +52,3 @@ class Safety(Protocol):
@webmethod(route="/safety/register_shield")
async def register_shield(self, shield: ShieldDef) -> None: ...
@webmethod(route="/safety/list_shields")
async def list_shields(self) -> List[ShieldDef]: ...
@webmethod(route="/safety/get_shield")
async def get_shield(self, identifier: str) -> Optional[ShieldDef]: ...

View file

@@ -39,6 +39,14 @@ RoutedProtocol = Union[
]
class ModelRegistry(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class MemoryBankRegistry(Protocol):
def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"

View file

@@ -151,6 +151,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
deps,
inner_impls,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else:

View file

@@ -15,6 +15,20 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
# TODO: this routing table maintains state in memory purely. We need to
# add persistence to it when we add dynamic registration of objects.
class CommonRoutingTableImpl(RoutingTable):
@@ -32,6 +46,15 @@ class CommonRoutingTableImpl(RoutingTable):
self.impls_by_provider_id = impls_by_provider_id
self.registry = registry
for p in self.impls_by_provider_id.values():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
self.routing_key_to_object = {}
for obj in self.registry:
self.routing_key_to_object[obj.identifier] = obj
@@ -39,7 +62,7 @@ class CommonRoutingTableImpl(RoutingTable):
async def initialize(self) -> None:
for obj in self.registry:
p = self.impls_by_provider_id[obj.provider_id]
await self.register_object(obj, p)
await register_object_with_provider(obj, p)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@@ -57,7 +80,7 @@ class CommonRoutingTableImpl(RoutingTable):
return obj
return None
async def register_object_common(self, obj: RoutableObject) -> None:
async def register_object(self, obj: RoutableObject) -> Any:
if obj.identifier in self.routing_key_to_object:
raise ValueError(f"Object `{obj.identifier}` already registered")
@@ -65,16 +88,13 @@
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
await p.register_object(obj)
await register_object_with_provider(obj, p)
self.routing_key_to_object[obj.identifier] = obj
self.registry.append(obj)
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_object(self, obj: ModelDef, p: Inference) -> None:
await p.register_model(obj)
async def list_models(self) -> List[ModelDef]:
return self.registry
@@ -82,13 +102,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return self.get_object_by_identifier(identifier)
async def register_model(self, model: ModelDef) -> None:
await self.register_object_common(model)
await self.register_object(model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_object(self, obj: ShieldDef, p: Safety) -> None:
await p.register_shield(obj)
async def list_shields(self) -> List[ShieldDef]:
return self.registry
@@ -96,13 +113,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return self.get_object_by_identifier(shield_type)
async def register_shield(self, shield: ShieldDef) -> None:
await self.register_object_common(shield)
await self.register_object(shield)
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def register_object(self, obj: MemoryBankDef, p: Memory) -> None:
await p.register_memory_bank(obj)
async def list_memory_banks(self) -> List[MemoryBankDef]:
return self.registry
@@ -110,4 +124,4 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return self.get_object_by_identifier(identifier)
async def register_memory_bank(self, bank: MemoryBankDef) -> None:
await self.register_object_common(bank)
await self.register_object(bank)
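Taken together, registration now flows through a single generic path: `register_object` rejects duplicates, dispatches to the owning provider via `register_object_with_provider`, and records the object, leaving the per-API routing tables as thin wrappers. A hypothetical end-to-end sketch (inside an async context; the keyword arguments assume the constructor mirrors the attributes set in `__init__`, and `ollama_impl` stands in for a real Inference provider):

table = ModelsRoutingTable(
    impls_by_provider_id={"ollama": ollama_impl},
    registry=[],
)
await table.initialize()  # re-registers any pre-seeded objects with their providers

model = ModelDef(identifier="Llama3.1-8B-Instruct", provider_id="ollama")
await table.register_model(model)
# register_object() checked for duplicates, called ollama_impl.register_model(model)
# via register_object_with_provider(), then recorded the object so that
# get_model("Llama3.1-8B-Instruct") and list_models() can serve it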

View file

@@ -6,39 +6,40 @@
from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
class DatabricksInferenceAdapter(Inference):
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> OpenAI:
return OpenAI(
base_url=self.config.url,
api_key=self.config.api_token
)
return OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def initialize(self) -> None:
return
@@ -65,18 +66,6 @@ class DatabricksInferenceAdapter(Inference):
return databricks_messages
def resolve_databricks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in DATABRICKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(DATABRICKS_SUPPORTED_MODELS.keys())}"
return DATABRICKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@@ -110,10 +99,9 @@ class DatabricksInferenceAdapter(Inference):
messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model)
databricks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
@@ -154,10 +142,7 @@ class DatabricksInferenceAdapter(Inference):
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
if stop_reason is None and chunk.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None

View file

@@ -21,7 +21,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
OLLAMA_SUPPORTED_SKUS = {
OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
@@ -33,7 +33,7 @@ OLLAMA_SUPPORTED_SKUS = {
class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, url: str) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_SKUS
self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
)
self.url = url
tokenizer = Tokenizer.get_instance()

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleInferenceImpl(Inference, RoutableProvider):
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_model(self, model: ModelDef) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@@ -12,6 +12,7 @@ from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
@@ -32,16 +33,18 @@ class _HfAdapter(Inference):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
# TODO: make this work properly by checking this against the model_id being
# served by the remote endpoint
async def register_model(self, model: ModelDef) -> None:
pass
resolved_model = resolve_model(model.identifier)
if resolved_model is None:
raise ValueError(f"Unknown model: {model.identifier}")
async def list_models(self) -> List[ModelDef]:
return []
if not resolved_model.huggingface_repo:
raise ValueError(
f"Model {model.identifier} does not have a HuggingFace repo"
)
async def get_model(self, identifier: str) -> Optional[ModelDef]:
return None
if self.model_id != resolved_model.huggingface_repo:
raise ValueError(f"Model mismatch: {model.identifier} != {self.model_id}")
async def shutdown(self) -> None:
pass

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import uuid
from typing import List
from urllib.parse import urlparse
@@ -13,7 +12,6 @@ import chromadb
from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
@@ -65,7 +63,7 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class ChromaMemoryAdapter(Memory, RoutableProvider):
class ChromaMemoryAdapter(Memory):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@@ -93,48 +91,33 @@ class ChromaMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[chroma] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
collection = await self.client.create_collection(
name=bank_id,
metadata={"bank": bank.json()},
name=memory_bank.identifier,
)
bank_index = BankWithIndex(
bank=bank, index=ChromaIndex(self.client, collection)
bank=memory_bank, index=ChromaIndex(self.client, collection)
)
self.cache[bank_id] = bank_index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
self.cache[memory_bank.identifier] = bank_index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if bank is None:
raise ValueError(f"Bank {bank_id} not found")
collections = await self.client.list_collections()
for collection in collections:
if collection.name == bank_id:
print(collection.metadata)
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),

View file

@@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import List, Tuple
from typing import List
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
@@ -32,33 +28,6 @@ def check_extension_version(cur):
return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, keys: List[str], cls):
query = "SELECT key, data FROM metadata_store"
if keys:
placeholders = ",".join(["%s"] * len(keys))
query += f" WHERE key IN ({placeholders})"
cur.execute(query, keys)
else:
cur.execute(query)
rows = cur.fetchall()
return [cls(**row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: MemoryBank, dimension: int, cursor):
self.cursor = cursor
@@ -119,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class PGVectorMemoryAdapter(Memory, RoutableProvider):
class PGVectorMemoryAdapter(Memory):
def __init__(self, config: PGVectorConfig) -> None:
print(f"Initializing PGVectorMemoryAdapter -> {config.host}:{config.port}")
self.config = config
@@ -144,14 +113,6 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
else:
raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e:
import traceback
@@ -161,51 +122,28 @@ class PGVectorMemoryAdapter(Memory, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
print(f"[pgvector] Registering memory bank routing keys: {routing_keys}")
pass
async def create_memory_bank(
async def register_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
upsert_models(
self.cursor,
[
(bank.bank_id, bank),
],
)
memory_bank: MemoryBankDef,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
banks = load_models(self.cursor, [bank_id], MemoryBank)
if not banks:
return None
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
bank = banks[0]
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleMemoryImpl(Memory, RoutableProvider):
class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@@ -7,14 +7,12 @@
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
@@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield.value,
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
@@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
They contain content and role. For now we extract the content and default the "qualifiers" to ["query"].
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
shield_params = shield_def.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
return None

View file

@@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleSafetyImpl(Safety, RoutableProvider):
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_shield(self, shield: ShieldDef) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_MODEL_MAP = {
TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@@ -22,7 +22,6 @@ SAFETY_SHIELD_MODEL_MAP = {
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
self.register_shields = []
async def initialize(self) -> None:
pass
@@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if shield.type != ShieldType.llama_guard.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
self.registered_shields.append(shield)
async def list_shields(self) -> List[ShieldDef]:
return self.registered_shields
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
for shield in self.registered_shields:
if shield.identifier == identifier:
return shield
return None
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in SAFETY_SHIELD_MODEL_MAP:
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None
@@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(together_api_key, model, api_messages)
violation = await get_safety_response(
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
)
return RunShieldResponse(violation=violation)

View file

@@ -83,15 +83,6 @@ class FaissMemoryImpl(Memory):
)
self.cache[memory_bank.identifier] = index
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
index = self.cache.get(identifier)
if index is None:
return None
return index.bank
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [x.bank for x in self.cache.values()]
async def insert_documents(
self,
bank_id: str,

View file

@@ -33,7 +33,6 @@ class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.registered_shields = []
self.available_shields = [ShieldType.code_scanner.value]
if config.llama_guard_shield:
@@ -55,24 +54,13 @@
if shield.type not in self.available_shields:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
self.registered_shields.append(shield)
async def list_shields(self) -> List[ShieldDef]:
return self.registered_shields
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
for shield in self.registered_shields:
if shield.identifier == identifier:
return shield
return None
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List
from typing import Dict
from llama_models.sku_list import resolve_model
@@ -15,7 +15,6 @@ class ModelRegistryHelper:
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
self.registered_models = []
def map_to_provider_model(self, identifier: str) -> str:
model = resolve_model(identifier)
@@ -30,22 +29,7 @@ class ModelRegistryHelper:
return self.stack_to_provider_models_map[identifier]
async def register_model(self, model: ModelDef) -> None:
existing = await self.get_model(model.identifier)
if existing is not None:
return
if model.identifier not in self.stack_to_provider_models_map:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)
self.registered_models.append(model)
async def list_models(self) -> List[ModelDef]:
return self.registered_models
async def get_model(self, identifier: str) -> Optional[ModelDef]:
for model in self.registered_models:
if model.identifier == identifier:
return model
return None
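With its internal registry gone, `ModelRegistryHelper` is stateless: it only validates stack identifiers and maps them to provider-specific model names (the post-commit body of `register_model` is not shown in this hunk; presumably it reduces to a check against `stack_to_provider_models_map`). A usage sketch under that assumption, reusing the Ollama mapping from earlier in this commit:

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

helper = ModelRegistryHelper(
    stack_to_provider_models_map={"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"}
)

provider_model = helper.map_to_provider_model("Llama3.1-8B-Instruct")
# -> "llama3.1:8b-instruct-fp16"; identifiers outside the map fail validation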