This commit is contained in:
Swapna Lekkala 2025-10-12 07:07:53 -07:00
parent 1954b60600
commit a720dbb942
2 changed files with 22 additions and 18 deletions

View file

@ -22,6 +22,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .config import PromptGuardConfig, PromptGuardType
@ -90,25 +91,9 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
def _extract_text_from_openai_content(self, content) -> str:
    """Collect the textual pieces of an OpenAI-style message content value.

    A plain string is returned unchanged.  A list of content parts is
    reduced to the space-joined ``.text`` of every part that either
    declares ``type == "text"`` or simply exposes a ``text`` attribute;
    parts without text (e.g. images or files) contribute nothing.

    Raises:
        ValueError: if ``content`` is neither a string nor a list.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # A part contributes its text when it is an explicit text part,
        # or when it merely carries a ``text`` attribute.
        return " ".join(
            part.text
            for part in content
            if getattr(part, "type", None) == "text" or hasattr(part, "text")
        )
    raise ValueError(f"Unsupported content type: {type(content)}")
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
message = messages[-1]
text = self._extract_text_from_openai_content(message.content)
text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")

View file

@ -23,6 +23,9 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
ResponseFormat,
ResponseFormatType,
SystemMessage,
@ -74,7 +77,17 @@ def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessag
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
def interleaved_content_as_str(
content: InterleavedContent
| str
| list[OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile]
| list[OpenAIChatCompletionContentPartTextParam]
| None,
sep: str = " ",
) -> str:
if content is None:
return ""
def _process(c) -> str:
if isinstance(c, str):
return c
@ -82,6 +95,12 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
return "<image>"
elif isinstance(c, TextContentItem):
return c.text
elif isinstance(c, OpenAIChatCompletionContentPartTextParam):
return c.text
elif isinstance(c, OpenAIChatCompletionContentPartImageParam):
return "<image>"
elif isinstance(c, OpenAIFile):
return "<file>"
else:
raise ValueError(f"Unsupported content type: {type(c)}")