[API Updates] Model / shield / memory-bank routing + agent persistence + support for private headers (#92)

This is yet another of those large PRs (hopefully we will have fewer and fewer of them as things mature). This one introduces substantial improvements and some simplifications to the stack.

Most important bits:

* The agents reference implementation now supports session / turn persistence. The default implementation uses SQLite; Redis is also supported (see the persistence sketches further below).

* We have re-architected the structure of the Stack APIs to allow for more flexible routing (see the routing sketch after this list). The motivating use cases are:
  - routing model A to ollama and model B to a remote provider like Together
  - routing shield A to a local implementation and shield B to a remote provider like Bedrock
  - routing a vector memory bank to Weaviate and a key-value memory bank to Redis

* Support for provider-specific parameters passed in by clients. A client can send data using the `x_llamastack_provider_data` parameter, which can be type-checked and provided to the adapter implementations (a sketch of this mechanism follows below).
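
To make the routing idea concrete, here is a minimal sketch of what such a routing table could look like, written as a plain Python dict. The keys, routing keys, and provider identifiers below are illustrative assumptions, not the exact config schema introduced by this PR:

```python
# Illustrative only: one routing entry per API, mapping a routing key
# (model name, shield type, memory bank type) to a serving provider.
routing_table = {
    "inference": [
        {"routing_key": "model-A", "provider_type": "remote::ollama"},
        {"routing_key": "model-B", "provider_type": "remote::together"},
    ],
    "safety": [
        {"routing_key": "shield-A", "provider_type": "meta-reference"},
        {"routing_key": "shield-B", "provider_type": "remote::bedrock"},
    ],
    "memory": [
        {"routing_key": "vector", "provider_type": "remote::weaviate"},
        {"routing_key": "keyvalue", "provider_type": "remote::redis"},
    ],
}
```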
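
And a hedged sketch of the private-header mechanism: the client serializes provider-specific data into the request, and the server type-checks it with a pydantic model before handing it to the adapter. The schema class, field name, and exact key spelling here are assumptions for illustration:

```python
import json
from typing import Optional

from pydantic import BaseModel, ValidationError

class TogetherProviderData(BaseModel):
    # Hypothetical schema an adapter might declare for its private data.
    together_api_key: str

def parse_provider_data(headers: dict) -> Optional[TogetherProviderData]:
    # Key name taken from the PR description; actual transport may differ.
    raw = headers.get("x_llamastack_provider_data")
    if raw is None:
        return None
    try:
        return TogetherProviderData(**json.loads(raw))
    except (json.JSONDecodeError, ValidationError) as e:
        raise ValueError(f"Malformed provider data: {e}") from e

# Client side: attach the serialized data to the request.
headers = {"x_llamastack_provider_data": json.dumps({"together_api_key": "sk-..."})}
print(parse_provider_data(headers))
```
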
Commit ec4fc800cc (parent 8bf8c07eb3) by Ashwin Bharambe, committed via GitHub on 2024-09-23 14:22:22 -07:00.
130 changed files with 9701 additions and 11227 deletions.


@@ -8,18 +8,14 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceImplConfig
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
):
from .agents import MetaReferenceAgentsImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],


@@ -25,10 +25,21 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import SafeTool
def make_random_string(length: int = 8):
@@ -40,23 +51,44 @@ def make_random_string(length: int = 8):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
agent_id: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
max_infer_iters: int = 10,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
SafeTool(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
ShieldRunnerMixin.__init__(
self,
@@ -80,7 +112,6 @@ class ChatAgent(ShieldRunnerMixin):
msg.context = None
messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
@@ -94,43 +125,35 @@
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
content=violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4())
session = Session(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
)
self.sessions[session_id] = session
return session
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session = self.sessions[request.session_id]
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(session.turns) == 0 and self.agent_config.instructions != "":
if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(session.turns):
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
@@ -148,7 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
output_message = None
async for chunk in self.run(
session=session,
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
@@ -187,7 +210,7 @@
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@@ -200,7 +223,7 @@
async def run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@@ -212,7 +235,7 @@
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
@@ -221,7 +244,7 @@
yield res
async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream
session_id, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@@ -235,7 +258,7 @@
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
@@ -245,11 +268,12 @@
yield final_response
async def run_shields_wrapper(
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[ShieldDefinition],
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
@@ -266,7 +290,7 @@
)
)
)
await self.run_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
@@ -276,7 +300,7 @@
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@@ -295,12 +319,7 @@
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@@ -308,7 +327,7 @@
async def _run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@@ -332,9 +351,10 @@
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments
)
with tracing.span("retrieve_rag_context"):
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@@ -387,55 +407,57 @@
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
)
)
)
)
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
)
)
)
)
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
if event.stop_reason is not None:
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
message = CompletionMessage(
@@ -461,7 +483,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
if n_iter >= self.max_infer_iters:
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break
@@ -512,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
with tracing.span("tool_execution"):
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@@ -550,12 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@@ -569,7 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@@ -594,17 +612,25 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
bank_id = memory_bank.bank_id
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return session.memory_bank
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
@@ -619,7 +645,6 @@ class ChatAgent(ShieldRunnerMixin):
else:
return True
print(f"{enabled_tools=}")
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
@@ -630,7 +655,7 @@ class ChatAgent(ShieldRunnerMixin):
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = []
@@ -639,8 +664,8 @@ class ChatAgent(ShieldRunnerMixin):
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
@@ -651,9 +676,12 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
await self.memory_api.insert_documents(bank.bank_id, documents)
elif session.memory_bank:
bank_ids.append(session.memory_bank.bank_id)
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated


@@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import tempfile
import uuid
from typing import AsyncGenerator
@@ -15,28 +14,19 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
from .tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
logger = logging.getLogger()
logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
config: MetaReferenceImplConfig,
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
@@ -45,9 +35,10 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.in_memory_store = InmemoryKVStoreImpl()
async def initialize(self) -> None:
pass
self.persistence_store = await kvstore_impl(self.config.persistence_store)
async def create_agent(
self,
@@ -55,38 +46,46 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
builtin_tools.append(
with_safety(
tool,
self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
async def get_agent(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
)
return AgentCreateResponse(
agent_id=agent_id,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
else self.in_memory_store
),
)
async def create_agent_session(
@@ -94,12 +93,11 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
agent = await self.get_agent(agent_id)
session = agent.create_session(session_name)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session.session_id,
session_id=session_id,
)
async def create_agent_turn(
@@ -115,6 +113,8 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
@@ -124,12 +124,5 @@ class MetaReferenceAgentsImpl(Agents):
stream=stream,
)
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request):
yield event
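
Taken together, the persistence-backed flow in this file looks roughly like the sketch below. Parameter names for `create_agent_turn` are inferred from the request object above, so treat them as assumptions:

```python
from llama_models.llama3.api.datatypes import UserMessage

async def demo_turn(impl: MetaReferenceAgentsImpl, agent_config: AgentConfig) -> None:
    # The agent config is persisted at creation time and re-hydrated on each call.
    create_resp = await impl.create_agent(agent_config)
    session_resp = await impl.create_agent_session(create_resp.agent_id, "demo-session")
    async for event in impl.create_agent_turn(
        agent_id=create_resp.agent_id,
        session_id=session_resp.session_id,
        messages=[UserMessage(content="Hello!")],
        stream=True,
    ):
        print(event)
```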


@@ -6,5 +6,8 @@
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig
class MetaReferenceImplConfig(BaseModel): ...
class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
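
For illustration, a minimal sketch of instantiating this config with a SQLite-backed or Redis-backed store, assuming `SqliteKVStoreConfig` / `RedisKVStoreConfig` helpers from the kvstore utilities (the class names, import path, and constructor arguments are assumptions):

```python
# Assumed helper configs; exact import path and fields may differ.
from llama_stack.providers.utils.kvstore import RedisKVStoreConfig, SqliteKVStoreConfig

sqlite_cfg = MetaReferenceAgentsImplConfig(
    persistence_store=SqliteKVStoreConfig(db_path="/tmp/agents.db"),
)
redis_cfg = MetaReferenceAgentsImplConfig(
    persistence_store=RedisKVStoreConfig(host="localhost", port=6379),
)
```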


@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
started_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
self.kvstore = kvstore
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
return session_id
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return AgentSessionInfo(**json.loads(value))
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
session_info.memory_bank_id = bank_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
continue
return turns
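
A quick usage sketch of `AgentPersistence`, using the in-memory kvstore that this PR also wires up (the SQLite and Redis stores behave the same way behind the `KVStore` interface):

```python
import asyncio

from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl

async def demo() -> None:
    persistence = AgentPersistence(agent_id="agent-123", kvstore=InmemoryKVStoreImpl())
    session_id = await persistence.create_session("my-session")
    info = await persistence.get_session_info(session_id)
    assert info is not None and info.session_name == "my-session"
    # No turns recorded yet, so this returns an empty list.
    print(await persistence.get_session_turns(session_id))

asyncio.run(demo())
```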


@@ -4,51 +4,48 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from llama_stack.apis.safety import (
OnViolationAction,
Safety,
ShieldDefinition,
ShieldResponse,
)
from llama_stack.apis.safety import * # noqa: F403
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
def __init__(self, violation: SafetyViolation):
self.violation = violation
super().__init__(violation.user_message)
class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
results = await self.safety_api.run_shields(
messages=messages,
shields=shields,
async def run_multiple_shields(
self, messages: List[Message], shields: List[str]
) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
messages=messages,
)
for shield_type in shields
]
)
for shield, r in zip(shields, results):
if r.is_violation:
for shield, r in zip(shields, responses):
if r.violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
@@ -56,5 +53,3 @@ class ShieldRunnerMixin:
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results


@@ -5,7 +5,6 @@
# the root directory of this source tree.
from typing import AsyncIterator, List, Optional, Union
from unittest.mock import MagicMock
import pytest
@@ -79,10 +78,10 @@ class MockInferenceAPI:
class MockSafetyAPI:
async def run_shields(
self, messages: List[Message], shields: List[MagicMock]
) -> List[ShieldResponse]:
return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
async def run_shield(
self, shield_type: str, messages: List[Message]
) -> RunShieldResponse:
return RunShieldResponse(violation=None)
class MockMemoryAPI:
@@ -185,6 +184,7 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
# ),
],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[],
output_shields=[],
)
@@ -221,13 +221,13 @@ async def test_chat_agent_create_and_execute_turn(chat_agent):
@pytest.mark.asyncio
async def test_run_shields_wrapper(chat_agent):
async def test_run_multiple_shields_wrapper(chat_agent):
messages = [UserMessage(content="Test message")]
shields = [ShieldDefinition(shield_type="test_shield")]
shields = ["test_shield"]
responses = [
chunk
async for chunk in chat_agent.run_shields_wrapper(
async for chunk in chat_agent.run_multiple_shields_wrapper(
turn_id="test_turn_id",
messages=messages,
shields=shields,


@@ -7,7 +7,7 @@
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, ShieldDefinition
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
@@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
@@ -30,29 +30,14 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
)
def get_name(self) -> str:
# return the name of the wrapped tool
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_shields(messages, self.input_shields)
await self.run_multiple_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_shields(messages, self.output_shields)
await self.run_multiple_shields(messages, self.output_shields)
return res
def with_safety(
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
) -> SafeTool:
return SafeTool(
tool,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
)


@@ -6,17 +6,14 @@
from typing import Optional
from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Meta-Llama3.1-8B-Instruct",
@@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
m.descriptor()
for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)


@@ -57,7 +57,7 @@ class MetaReferenceInferenceImpl(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tools: Optional[List[ToolDefinition]] = [],
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
@@ -70,7 +70,7 @@ class MetaReferenceInferenceImpl(Inference):
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,


@@ -42,7 +42,6 @@ class FaissIndex(EmbeddingIndex):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))


@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
@@ -11,6 +12,13 @@ from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator
class MetaReferenceShieldType(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-8B"
excluded_categories: List[str] = []


@@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from .config import MetaReferenceShieldType, SafetyConfig
from .config import SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
@@ -19,7 +19,6 @@ from .shields import (
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
ThirdPartyShield,
)
@@ -50,46 +49,58 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def run_shields(
async def run_shield(
self,
shield_type: str,
messages: List[Message],
shields: List[ShieldDefinition],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in shields]
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
return RunShieldResponse(responses=responses)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
def shield_type_equals(a: ShieldType, b: ShieldType):
return a == b or a == b.value
return RunShieldResponse(violation=violation)
def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
assert (
safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
return CodeScannerShield.instance()
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
return ThirdPartyShield.instance()
else:
raise ValueError(f"Unknown shield type: {sc.shield_type}")
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
assert (
cfg.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard:
return CodeScannerShield.instance()
else:
raise ValueError(f"Unknown shield type: {typ}")
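
A short sketch of the new call path that replaces `run_shields`: callers now name a single shield by its type string and inspect the optional violation on the response:

```python
from llama_models.llama3.api.datatypes import UserMessage

async def check(impl: MetaReferenceSafetyImpl) -> None:
    response = await impl.run_shield(
        shield_type=MetaReferenceShieldType.llama_guard.value,
        messages=[UserMessage(content="How do I bake a cake?")],
    )
    if response.violation is not None:
        print(response.violation.user_message)  # canned refusal from the shield
    else:
        print("no violation")
```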


@@ -15,7 +15,6 @@ from .base import ( # noqa: F401
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,


@@ -8,11 +8,26 @@ from abc import ABC, abstractmethod
from typing import List
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from pydantic import BaseModel
from llama_stack.apis.safety import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
# TODO: clean this up; just remove this type completely
class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
# TODO: this is a caller / agent concern
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
class ShieldBase(ABC):
def __init__(
self,
@@ -20,10 +35,6 @@ class ShieldBase(ABC):
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
@@ -48,11 +59,6 @@ class TextShield(ShieldBase):
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(
shield_type=BuiltinShield.third_party_shield, is_violation=False
)
return ShieldResponse(is_violation=False)


@@ -7,13 +7,9 @@
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_stack.apis.safety import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
@@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)
return ShieldResponse(is_violation=False)


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,35 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import Message
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
super.run() # will raise NotImplementedError


@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.apis.safety import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
@@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:


@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_stack.apis.safety import * # noqa: F403
class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)