This commit is contained in:
Swapna Lekkala 2025-10-10 09:16:15 -07:00
parent e09401805f
commit b5c951fa4b
10 changed files with 40 additions and 178 deletions

View file

@ -133,6 +133,12 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
@ -884,18 +890,6 @@ class OpenAIResponseContentPartOutputText(BaseModel):
logprobs: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
@json_schema_type
class OpenAIResponseContentPartReasoningText(BaseModel):
"""Reasoning text emitted as part of a streamed response.

View file

@ -52,14 +52,6 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="openai_responses_utils")
# ============================================================================
# Message and Content Conversion Functions
# ============================================================================
async def convert_chat_choice_to_response_message(
@ -325,11 +317,6 @@ def is_function_tool_call(
return False
# ============================================================================
# Safety and Shield Validation Functions
# ============================================================================
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
@ -359,7 +346,7 @@ def extract_shield_ids(shields: list | None) -> list[str]:
elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type)
else:
logger.warning(f"Unknown shield format: {shield}")
raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")
return shield_ids