mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
clean
This commit is contained in:
parent
1954b60600
commit
a720dbb942
2 changed files with 22 additions and 18 deletions
|
|
@ -22,6 +22,7 @@ from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.core.utils.model_utils import model_local_dir
|
from llama_stack.core.utils.model_utils import model_local_dir
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
|
|
||||||
from .config import PromptGuardConfig, PromptGuardType
|
from .config import PromptGuardConfig, PromptGuardType
|
||||||
|
|
||||||
|
|
@ -90,25 +91,9 @@ class PromptGuardShield:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||||
|
|
||||||
def _extract_text_from_openai_content(self, content) -> str:
|
|
||||||
"""Extract text content from OpenAI message content format."""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
elif isinstance(content, list):
|
|
||||||
text_parts = []
|
|
||||||
for part in content:
|
|
||||||
if hasattr(part, "type") and part.type == "text":
|
|
||||||
text_parts.append(part.text)
|
|
||||||
elif hasattr(part, "text"):
|
|
||||||
text_parts.append(part.text)
|
|
||||||
# Skip non-text parts like images or files
|
|
||||||
return " ".join(text_parts)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
|
||||||
|
|
||||||
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
|
||||||
message = messages[-1]
|
message = messages[-1]
|
||||||
text = self._extract_text_from_openai_content(message.content)
|
text = interleaved_content_as_str(message.content)
|
||||||
|
|
||||||
# run model on messages and return response
|
# run model on messages and return response
|
||||||
inputs = self.tokenizer(text, return_tensors="pt")
|
inputs = self.tokenizer(text, return_tensors="pt")
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,9 @@ from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
Message,
|
Message,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
OpenAIFile,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
|
@ -74,7 +77,17 @@ def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessag
|
||||||
return formatter.decode_assistant_message_from_content(content, stop_reason)
|
return formatter.decode_assistant_message_from_content(content, stop_reason)
|
||||||
|
|
||||||
|
|
||||||
def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
|
def interleaved_content_as_str(
|
||||||
|
content: InterleavedContent
|
||||||
|
| str
|
||||||
|
| list[OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile]
|
||||||
|
| list[OpenAIChatCompletionContentPartTextParam]
|
||||||
|
| None,
|
||||||
|
sep: str = " ",
|
||||||
|
) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
def _process(c) -> str:
|
def _process(c) -> str:
|
||||||
if isinstance(c, str):
|
if isinstance(c, str):
|
||||||
return c
|
return c
|
||||||
|
|
@ -82,6 +95,12 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
|
||||||
return "<image>"
|
return "<image>"
|
||||||
elif isinstance(c, TextContentItem):
|
elif isinstance(c, TextContentItem):
|
||||||
return c.text
|
return c.text
|
||||||
|
elif isinstance(c, OpenAIChatCompletionContentPartTextParam):
|
||||||
|
return c.text
|
||||||
|
elif isinstance(c, OpenAIChatCompletionContentPartImageParam):
|
||||||
|
return "<image>"
|
||||||
|
elif isinstance(c, OpenAIFile):
|
||||||
|
return "<file>"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported content type: {type(c)}")
|
raise ValueError(f"Unsupported content type: {type(c)}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue