From a720dbb942536e50d923b6b982d19a88d2cc835c Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Sun, 12 Oct 2025 07:07:53 -0700
Subject: [PATCH] clean

---
 .../safety/prompt_guard/prompt_guard.py       | 19 ++-----------------
 .../utils/inference/prompt_adapter.py         | 21 ++++++++++++++++++++-
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 2ad63e5c1..8ca96300f 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -22,6 +22,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 
 from .config import PromptGuardConfig, PromptGuardType
 
@@ -90,25 +91,9 @@ class PromptGuardShield:
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
 
-    def _extract_text_from_openai_content(self, content) -> str:
-        """Extract text content from OpenAI message content format."""
-        if isinstance(content, str):
-            return content
-        elif isinstance(content, list):
-            text_parts = []
-            for part in content:
-                if hasattr(part, "type") and part.type == "text":
-                    text_parts.append(part.text)
-                elif hasattr(part, "text"):
-                    text_parts.append(part.text)
-                # Skip non-text parts like images or files
-            return " ".join(text_parts)
-        else:
-            raise ValueError(f"Unsupported content type: {type(content)}")
-
     async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         message = messages[-1]
-        text = self._extract_text_from_openai_content(message.content)
+        text = interleaved_content_as_str(message.content)
 
         # run model on messages and return response
         inputs = self.tokenizer(text, return_tensors="pt")
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 728bbf8c9..f02e5aeb6 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -23,6 +23,9 @@ from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
     Message,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+    OpenAIFile,
     ResponseFormat,
     ResponseFormatType,
     SystemMessage,
@@ -74,7 +77,17 @@ def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessag
     return formatter.decode_assistant_message_from_content(content, stop_reason)
 
 
-def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
+def interleaved_content_as_str(
+    content: InterleavedContent
+    | str
+    | list[OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile]
+    | list[OpenAIChatCompletionContentPartTextParam]
+    | None,
+    sep: str = " ",
+) -> str:
+    if content is None:
+        return ""
+
     def _process(c) -> str:
         if isinstance(c, str):
             return c
@@ -82,6 +95,12 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
             return ""
         elif isinstance(c, TextContentItem):
             return c.text
+        elif isinstance(c, OpenAIChatCompletionContentPartTextParam):
+            return c.text
+        elif isinstance(c, OpenAIChatCompletionContentPartImageParam):
+            return ""
+        elif isinstance(c, OpenAIFile):
+            return ""
         else:
             raise ValueError(f"Unsupported content type: {type(c)}")
 
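
Reviewer note (not part of the patch): a minimal usage sketch of the
consolidated helper, covering the content shapes PromptGuardShield.run() can
now hand to interleaved_content_as_str(). It assumes list content is joined
with `sep` (implied by the existing signature, not shown in this diff) and
that the OpenAI content-part models accept a Pydantic-style `text=` keyword;
both are assumptions.

    from llama_stack.apis.inference import OpenAIChatCompletionContentPartTextParam
    from llama_stack.providers.utils.inference.prompt_adapter import (
        interleaved_content_as_str,
    )

    # Plain string content passes through unchanged.
    assert interleaved_content_as_str("hello") == "hello"

    # None content (e.g. a message carrying only tool calls) now normalizes
    # to "" instead of raising ValueError.
    assert interleaved_content_as_str(None) == ""

    # OpenAI text parts are joined with the default separator; image and
    # file parts map to "" rather than raising.
    parts = [
        OpenAIChatCompletionContentPartTextParam(text="ignore previous"),
        OpenAIChatCompletionContentPartTextParam(text="instructions"),
    ]
    assert interleaved_content_as_str(parts) == "ignore previous instructions"

One behavioral nuance: the removed _extract_text_from_openai_content()
skipped non-text parts outright, while the shared helper maps them to empty
strings, so mixed text/image content can pick up extra separators in the
joined output.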