mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 10:52:37 +00:00
remove recordings
This commit is contained in:
parent
82cbcada39
commit
67de6af0f0
36 changed files with 2453 additions and 1037 deletions
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
|
@ -22,9 +22,6 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
|
@ -56,7 +53,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
|
|
@ -93,9 +90,25 @@ class PromptGuardShield:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
def _extract_text_from_openai_content(self, content) -> str:
|
||||
"""Extract text content from OpenAI message content format."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if hasattr(part, "type") and part.type == "text":
|
||||
text_parts.append(part.text)
|
||||
elif hasattr(part, "text"):
|
||||
text_parts.append(part.text)
|
||||
# Skip non-text parts like images or files
|
||||
return " ".join(text_parts)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
|
||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_content_as_str(message.content)
|
||||
text = self._extract_text_from_openai_content(message.content)
|
||||
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue