Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-16 17:59:29 +00:00
Commit f820123b99 ("clean up"), parent 9152efa1a9
3 changed files with 3 additions and 11 deletions
@@ -134,6 +134,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
 @json_schema_type
 class OpenAIResponseContentPartRefusal(BaseModel):
     """Refusal content within a streamed response part.

     :param type: Content part type identifier, always "refusal"
     :param refusal: Refusal text supplied by the model
     """
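The hunk above shows the OpenAIResponseContentPartRefusal schema type and its docstring but not its fields. A minimal sketch of how such a Pydantic model could be declared, assuming the two fields described in the docstring; the field types, the Literal default, and the omission of the repo's @json_schema_type decorator are illustrative assumptions, not taken from the diff:

from typing import Literal

from pydantic import BaseModel


class OpenAIResponseContentPartRefusal(BaseModel):
    # Fields inferred from the docstring in the hunk above; the default value is an assumption.
    type: Literal["refusal"] = "refusal"
    refusal: str


# Example: a streamed refusal part serialized for an API response.
part = OpenAIResponseContentPartRefusal(refusal="I can't help with that request.")
print(part.model_dump())  # {'type': 'refusal', 'refusal': "I can't help with that request."}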
@@ -312,6 +312,7 @@ def is_function_tool_call(
             return True
     return False


 async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
     """Run multiple shields against messages and raise SafetyException for violations."""
     if not shield_ids or not messages:
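The hunk shows only the signature and guard clause of run_multiple_shields; its body is not part of this diff. A minimal sketch of one plausible implementation, assuming the Safety API exposes an async run_shield(shield_id, messages, params) call whose response carries an optional violation; the method name, the response shape, and the SafetyException placeholder below are assumptions:

class SafetyException(Exception):
    # Placeholder standing in for the repo's SafetyException; carries the violation details.
    def __init__(self, violation):
        super().__init__(str(violation))
        self.violation = violation


async def run_multiple_shields(safety_api, messages, shield_ids) -> None:
    """Run multiple shields against messages and raise SafetyException for violations."""
    if not shield_ids or not messages:
        return

    for shield_id in shield_ids:
        # Assumed Safety API call shape; adjust to the actual interface.
        response = await safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
        if response.violation:
            raise SafetyException(response.violation)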
@@ -340,7 +341,7 @@ def extract_shield_ids(shields: list | None) -> list[str]:
         elif isinstance(shield, ResponseShieldSpec):
             shield_ids.append(shield.type)
         else:
-            raise ValueError(f"Unsupported shield type: {type(shield)}")
+            raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")

     return shield_ids
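Only the tail of extract_shield_ids appears in this hunk. A sketch of the whole helper consistent with the lines shown, assuming bare strings are passed through as shield IDs and that an empty or None input yields an empty list; those branches and the ResponseShieldSpec placeholder are inferred, not shown in the diff:

from pydantic import BaseModel


class ResponseShieldSpec(BaseModel):
    # Placeholder for the repo's ResponseShieldSpec; only the `type` field used below is assumed.
    type: str


def extract_shield_ids(shields: list | None) -> list[str]:
    """Normalize a mixed list of shield specs into shield ID strings."""
    shield_ids: list[str] = []
    if not shields:
        return shield_ids

    for shield in shields:
        if isinstance(shield, str):
            # Assumed: bare strings are already shield IDs.
            shield_ids.append(shield)
        elif isinstance(shield, ResponseShieldSpec):
            shield_ids.append(shield.type)
        else:
            raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")

    return shield_ids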
@@ -38,11 +38,6 @@ def responses_impl(mock_apis):
     return OpenAIResponsesImpl(**mock_apis)


-# ============================================================================
-# Shield ID Extraction Tests
-# ============================================================================
-
-
 def test_extract_shield_ids_from_strings(responses_impl):
     """Test extraction from simple string shield IDs."""
     shields = ["llama-guard", "content-filter", "nsfw-detector"]
@@ -92,11 +87,6 @@ def test_extract_shield_ids_unknown_format(responses_impl):
         extract_shield_ids(shields)


-# ============================================================================
-# Text Content Extraction Tests
-# ============================================================================
-
-
 def test_extract_text_content_string(responses_impl):
     """Test extraction from simple string content."""
     content = "Hello world"
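For reference, a sketch of how the shield extraction tests touched by these hunks might read end to end; the test names, fixture, and shields list come from the diff, while the assertions and the invalid input are illustrative assumptions:

import pytest

# extract_shield_ids would be imported from the module under test (import path omitted here).


def test_extract_shield_ids_from_strings(responses_impl):
    """Test extraction from simple string shield IDs."""
    shields = ["llama-guard", "content-filter", "nsfw-detector"]
    # Assumed assertion: string IDs pass through unchanged.
    assert extract_shield_ids(shields) == shields


def test_extract_shield_ids_unknown_format(responses_impl):
    """Unsupported entries should raise a ValueError."""
    shields = [{"not": "a shield spec"}]  # assumed invalid input
    with pytest.raises(ValueError):
        extract_shield_ids(shields)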