This commit is contained in:
Swapna Lekkala 2025-10-10 09:16:15 -07:00
parent 9152efa1a9
commit f820123b99
3 changed files with 3 additions and 11 deletions

View file

@ -134,6 +134,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
@json_schema_type @json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel): class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part. """Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal" :param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model :param refusal: Refusal text supplied by the model
""" """

View file

@ -312,6 +312,7 @@ def is_function_tool_call(
return True return True
return False return False
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None: async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations.""" """Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages: if not shield_ids or not messages:
@ -340,7 +341,7 @@ def extract_shield_ids(shields: list | None) -> list[str]:
elif isinstance(shield, ResponseShieldSpec): elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type) shield_ids.append(shield.type)
else: else:
raise ValueError(f"Unsupported shield type: {type(shield)}") raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")
return shield_ids return shield_ids

View file

@ -38,11 +38,6 @@ def responses_impl(mock_apis):
return OpenAIResponsesImpl(**mock_apis) return OpenAIResponsesImpl(**mock_apis)
# ============================================================================
# Shield ID Extraction Tests
# ============================================================================
def test_extract_shield_ids_from_strings(responses_impl): def test_extract_shield_ids_from_strings(responses_impl):
"""Test extraction from simple string shield IDs.""" """Test extraction from simple string shield IDs."""
shields = ["llama-guard", "content-filter", "nsfw-detector"] shields = ["llama-guard", "content-filter", "nsfw-detector"]
@ -92,11 +87,6 @@ def test_extract_shield_ids_unknown_format(responses_impl):
extract_shield_ids(shields) extract_shield_ids(shields)
# ============================================================================
# Text Content Extraction Tests
# ============================================================================
def test_extract_text_content_string(responses_impl): def test_extract_text_content_string(responses_impl):
"""Test extraction from simple string content.""" """Test extraction from simple string content."""
content = "Hello world" content = "Hello world"