Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-13 22:17:59 +00:00)
chore!: Safety api refactoring to use OpenAIMessageParam (#3796)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 40s
Pre-commit / pre-commit (push) Successful in 1m28s
# What does this PR do?

Remove usage of the deprecated `Message` type from the Safety APIs.

## Test Plan

CI
This commit is contained in:
parent 82cbcada39 · commit 3bb6ef351b
37 changed files with 2455 additions and 1050 deletions
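For orientation, here is a minimal sketch of what a call site looks like after this refactor. Only `OpenAIUserMessageParam` and the `run_shield` signature are taken from the diff below; the `safety_api` handle and the `moderate_text` helper are illustrative assumptions, not part of this PR:

```python
from llama_stack.apis.inference import OpenAIUserMessageParam


async def moderate_text(safety_api, shield_id: str, text: str):
    # Safety APIs now accept OpenAI-style message params instead of the
    # deprecated Message / UserMessage types (illustrative call site).
    messages = [OpenAIUserMessageParam(content=text)]
    return await safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
```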
@@ -12,10 +12,9 @@ from typing import Any
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     Inference,
-    Message,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIMessageParam,
     OpenAIUserMessageParam,
-    UserMessage,
 )
 from llama_stack.apis.safety import (
     RunShieldResponse,
@@ -165,7 +164,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
@@ -175,8 +174,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
+        if len(messages) > 0 and messages[0].role != "user":
+            messages[0] = OpenAIUserMessageParam(content=messages[0].content)

         # Use the inference API's model resolution instead of hardcoded mappings
         # This allows the shield to work with any registered model
@@ -208,7 +207,7 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
             messages = [input]

         # convert to user messages format with role
-        messages = [UserMessage(content=m) for m in messages]
+        messages = [OpenAIUserMessageParam(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
@@ -277,7 +276,7 @@ class LlamaGuardShield:

         return final_categories

-    def validate_messages(self, messages: list[Message]) -> None:
+    def validate_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -288,7 +287,7 @@ class LlamaGuardShield:

         return messages

-    async def run(self, messages: list[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         messages = self.validate_messages(messages)

         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@@ -307,10 +306,10 @@ class LlamaGuardShield:
         content = content.strip()
         return self.get_shield_response(content)

-    def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
-        return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
+    def build_text_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
+        return OpenAIUserMessageParam(content=self.build_prompt(messages))

-    def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
+    def build_vision_shield_input(self, messages: list[OpenAIMessageParam]) -> OpenAIUserMessageParam:
         conversation = []
         most_recent_img = None

@@ -333,7 +332,7 @@ class LlamaGuardShield:
                 else:
                     raise ValueError(f"Unknown content type: {c}")

-            conversation.append(UserMessage(content=content))
+            conversation.append(OpenAIUserMessageParam(content=content))
         else:
             raise ValueError(f"Unknown content type: {m.content}")

@@ -342,9 +341,9 @@ class LlamaGuardShield:
             prompt.append(most_recent_img)
         prompt.append(self.build_prompt(conversation[::-1]))

-        return OpenAIUserMessageParam(role="user", content=prompt)
+        return OpenAIUserMessageParam(content=prompt)

-    def build_prompt(self, messages: list[Message]) -> str:
+    def build_prompt(self, messages: list[OpenAIMessageParam]) -> str:
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
@@ -377,7 +376,7 @@ class LlamaGuardShield:

         raise ValueError(f"Unexpected response: {response}")

-    async def run_moderation(self, messages: list[Message]) -> ModerationObject:
+    async def run_moderation(self, messages: list[OpenAIMessageParam]) -> ModerationObject:
         if not messages:
             return self.create_moderation_object(self.model)

@@ -388,6 +387,7 @@ class LlamaGuardShield:
             model=self.model,
             messages=[shield_input_message],
             stream=False,
+            temperature=0.0,  # default is 1, which is too high for safety
         )
         response = await self.inference_api.openai_chat_completion(params)
         content = response.choices[0].message.content
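Note that the role check in `run_shield` becomes a plain string comparison against `"user"` rather than `Role.user.value`. A standalone sketch of that normalization step, lifted out of the provider for illustration (the `ensure_first_message_is_user` helper name is hypothetical):

```python
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam


def ensure_first_message_is_user(messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
    # Shields like Llama Guard require the first message to be a user message;
    # since the first message might be a tool call, coerce it when needed.
    messages = messages.copy()
    if len(messages) > 0 and messages[0].role != "user":
        messages[0] = OpenAIUserMessageParam(content=messages[0].content)
    return messages
```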