clean and fix tests

This commit is contained in:
Swapna Lekkala 2025-10-10 15:18:56 -07:00
parent ad4362e48d
commit 171fb7101d
3 changed files with 3 additions and 44 deletions

View file

@ -15,15 +15,12 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem, ListOpenAIResponseInputItem,
ListOpenAIResponseObject, ListOpenAIResponseObject,
OpenAIDeleteResponseObject, OpenAIDeleteResponseObject,
OpenAIResponseContentPartRefusal,
OpenAIResponseInput, OpenAIResponseInput,
OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseMessage, OpenAIResponseMessage,
OpenAIResponseObject, OpenAIResponseObject,
OpenAIResponseObjectStream, OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseText, OpenAIResponseText,
OpenAIResponseTextFormat, OpenAIResponseTextFormat,
) )
@ -300,30 +297,6 @@ class OpenAIResponsesImpl:
raise ValueError("The response stream never reached a terminal state") raise ValueError("The response stream never reached a terminal state")
return final_response return final_response
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create and yield refusal response events following the established streaming pattern."""
# Create initial response and yield created event
initial_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="in_progress",
output=[],
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Create completed refusal response using OpenAIResponseContentPartRefusal
refusal_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _create_streaming_response( async def _create_streaming_response(
self, self,
input: str | list[OpenAIResponseInput], input: str | list[OpenAIResponseInput],
@ -375,7 +348,6 @@ class OpenAIResponsesImpl:
shield_ids=shield_ids, shield_ids=shield_ids,
) )
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
# Stream the response # Stream the response
final_response = None final_response = None
failed_response = None failed_response = None

View file

@ -147,20 +147,6 @@ class StreamingResponseOrchestrator:
refusal=e.violation.user_message or "Content blocked by safety shields" refusal=e.violation.user_message or "Content blocked by safety shields"
) )
async def _create_input_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create refusal response events for input safety violations."""
# Create the refusal content part explicitly with the correct structure
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None: async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
"""Check accumulated streaming text content against shields. Returns violation message if blocked.""" """Check accumulated streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids or not accumulated_text: if not self.shield_ids or not accumulated_text:
@ -237,8 +223,7 @@ class StreamingResponseOrchestrator:
input_refusal = await self._check_input_safety(self.ctx.messages) input_refusal = await self._check_input_safety(self.ctx.messages)
if input_refusal: if input_refusal:
# Return refusal response immediately # Return refusal response immediately
async for refusal_event in self._create_input_refusal_response_events(input_refusal): yield await self._create_refusal_response(input_refusal.refusal)
yield refusal_event
return return
async for stream_event in self._process_tools(output_messages): async for stream_event in self._process_tools(output_messages):

View file

@ -38,6 +38,7 @@ def responses_impl_with_conversations(
mock_responses_store, mock_responses_store,
mock_vector_io_api, mock_vector_io_api,
mock_conversations_api, mock_conversations_api,
mock_safety_api,
): ):
"""Create OpenAIResponsesImpl instance with conversations API.""" """Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl( return OpenAIResponsesImpl(
@ -47,6 +48,7 @@ def responses_impl_with_conversations(
responses_store=mock_responses_store, responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api, vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api, conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
) )