Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 14:42:50 +00:00)

Commit 67de6af0f0 (parent 82cbcada39): remove recordings
36 changed files with 2453 additions and 1037 deletions

The hunks below migrate the Safety API and its providers from the legacy Message type to the OpenAI-compatible message params (OpenAIMessageParam / OpenAIUserMessageParam).
@@ -9,7 +9,7 @@ from typing import Any, Protocol, runtime_checkable
 from pydantic import BaseModel, Field
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -107,7 +107,7 @@ class Safety(Protocol):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
        params: dict[str, Any],
     ) -> RunShieldResponse:
         """Run shield.
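For orientation, this protocol change is what every provider hunk below conforms to: callers now pass OpenAI-style message params instead of the legacy Message type. A minimal caller-side sketch (not part of the commit), assuming a Safety implementation is in hand as safety_api and that RunShieldResponse exposes a violation field:

from llama_stack.apis.inference import OpenAIUserMessageParam

async def shield_check(safety_api, shield_id: str, text: str):
    # Wrap raw text in the new OpenAI-style user message type.
    messages = [OpenAIUserMessageParam(content=text)]
    response = await safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
    return response.violation  # None when the content passed the shield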
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -53,7 +53,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
@@ -12,10 +12,9 @@ from typing import Any
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     Inference,
-    Message,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIMessageParam,
     OpenAIUserMessageParam,
-    UserMessage,
 )
 from llama_stack.apis.safety import (
     RunShieldResponse,
@@ -165,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
@@ -175,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
+        if len(messages) > 0 and messages[0].role != "user":
+            messages[0] = OpenAIUserMessageParam(content=messages[0].content)
 
         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
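The enum comparison goes away because OpenAI-style params carry their role as a plain literal string ("user", "assistant", ...). A standalone sketch of the same normalization (function name here is illustrative):

from llama_stack.apis.inference import OpenAIUserMessageParam

def ensure_user_first(messages: list) -> list:
    # Llama Guard expects the conversation to open with a user turn; a tool or
    # assistant message in first position is re-wrapped as a user message.
    messages = messages.copy()
    if messages and messages[0].role != "user":
        messages[0] = OpenAIUserMessageParam(content=messages[0].content)
    return messages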
@@ -208,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             messages = [input]
 
         # convert to user messages format with role
-        messages = [UserMessage(content=m) for m in messages]
+        messages = [OpenAIUserMessageParam(content=m) for m in messages]
 
         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
@@ -277,7 +276,7 @@ class LlamaGuardShield:
 
         return final_categories
 
-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -288,7 +287,7 @@ class LlamaGuardShield:
 
         return messages
 
-    async def run(self, messages: list[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         messages = self.validate_messages(messages)
 
         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
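The annotation change from -> None to -> list[OpenAIMessageParam] matches how run() actually consumes it in the hunk above: the validated (possibly normalized) list is reassigned rather than mutated silently. A reduced, runnable sketch of the pattern:

from llama_stack.apis.inference import OpenAIUserMessageParam

def validate_messages(messages: list) -> list:
    # Reduced from the diff: validation may replace entries, so the list is
    # handed back for the caller to reassign.
    if len(messages) == 0:
        raise ValueError("Messages must not be empty")
    return messages

msgs = validate_messages([OpenAIUserMessageParam(content="hi")])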
@@ -307,10 +306,10 @@ class LlamaGuardShield:
         content = content.strip()
         return self.get_shield_response(content)
 
-    def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
-        return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
+    def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
+        return OpenAIUserMessageParam(content=self.build_prompt(messages))
 
-    def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
+    def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
         conversation = []
         most_recent_img = None
 
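Both builders stop passing role="user" explicitly; the diff implies OpenAIUserMessageParam pins its own role, making the keyword redundant. A quick check of that assumption:

from llama_stack.apis.inference import OpenAIUserMessageParam

msg = OpenAIUserMessageParam(content="hello")
assert msg.role == "user"  # role is supplied by the model's default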
@@ -333,7 +332,7 @@ class LlamaGuardShield:
                 else:
                     raise ValueError(f"Unknown content type: {c}")
 
-            conversation.append(UserMessage(content=content))
+            conversation.append(OpenAIUserMessageParam(content=content))
         else:
             raise ValueError(f"Unknown content type: {m.content}")
 
@@ -342,9 +341,9 @@ class LlamaGuardShield:
         prompt.append(most_recent_img)
         prompt.append(self.build_prompt(conversation[::-1]))
 
-        return OpenAIUserMessageParam(role="user", content=prompt)
+        return OpenAIUserMessageParam(content=prompt)
 
-    def build_prompt(self, messages: list[Message]) -> str:
+    def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
@@ -377,7 +376,7 @@ class LlamaGuardShield:
 
         raise ValueError(f"Unexpected response: {response}")
 
-    async def run_moderation(self, messages: list[Message]) -> ModerationObject:
+    async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
         if not messages:
             return self.create_moderation_object(self.model)
 
@@ -388,6 +387,7 @@ class LlamaGuardShield:
             model=self.model,
             messages=[shield_input_message],
             stream=False,
+            temperature=0.0,  # default is 1, which is too high for safety
         )
         response = await self.inference_api.openai_chat_completion(params)
         content = response.choices[0].message.content
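The added temperature=0.0 pins the judge model to near-deterministic verdicts. A sketch of the request construction implied by this hunk, using only the keywords visible above (the model id is illustrative):

from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)

params = OpenAIChatCompletionRequestWithExtraBody(
    model="llama-guard-model-id",  # illustrative; the shield resolves the real id
    messages=[OpenAIUserMessageParam(content="<rendered shield prompt>")],
    stream=False,
    temperature=0.0,  # the API default of 1.0 is too high for safety judgments
)
# response = await inference_api.openai_chat_completion(params)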
@@ -9,7 +9,7 @@ from typing import Any
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -22,9 +22,6 @@ from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 from .config import PromptGuardConfig, PromptGuardType
 
@@ -56,7 +53,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         params: dict[str, Any],
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
@@ -93,9 +90,25 @@ class PromptGuardShield:
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
 
-    async def run(self, messages: list[Message]) -> RunShieldResponse:
+    def _extract_text_from_openai_content(self, content) -> str:
+        """Extract text content from OpenAI message content format."""
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            text_parts = []
+            for part in content:
+                if hasattr(part, "type") and part.type == "text":
+                    text_parts.append(part.text)
+                elif hasattr(part, "text"):
+                    text_parts.append(part.text)
+                # Skip non-text parts like images or files
+            return " ".join(text_parts)
+        else:
+            raise ValueError(f"Unsupported content type: {type(content)}")
+
+    async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         message = messages[-1]
-        text = interleaved_content_as_str(message.content)
+        text = self._extract_text_from_openai_content(message.content)
 
         # run model on messages and return response
         inputs = self.tokenizer(text, return_tensors="pt")
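A standalone restatement of the new helper's contract, runnable on its own: string content passes through, list content keeps only the text parts, anything else raises. This is a sketch that condenses the shipped logic, not a copy of it:

def extract_text(content) -> str:
    # Mirrors _extract_text_from_openai_content above.
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Keep text parts; image/file parts carry no usable .text and are skipped.
        return " ".join(p.text for p in content if getattr(p, "text", None))
    raise ValueError(f"Unsupported content type: {type(content)}")

assert extract_text("jailbreak attempt") == "jailbreak attempt"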
@@ -7,7 +7,7 @@
 import json
 from typing import Any
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -56,7 +56,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
         pass
 
     async def run_shield(
-        self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
+        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -8,7 +8,7 @@ from typing import Any
 
 import requests
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
@@ -44,7 +44,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         pass
 
     async def run_shield(
-        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
+        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
     ) -> RunShieldResponse:
         """
         Run a safety shield check against the provided messages.
@@ -118,7 +118,7 @@ class NeMoGuardrails:
         response.raise_for_status()
         return response.json()
 
-    async def run(self, messages: list[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         """
         Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
 
@@ -10,7 +10,7 @@ from typing import Any
 import litellm
 import requests
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -72,7 +72,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
         pass
 
     async def run_shield(
-        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
+        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield: