clean and fix tests

This commit is contained in:
Swapna Lekkala 2025-10-10 15:18:56 -07:00
parent ad4362e48d
commit 171fb7101d
3 changed files with 3 additions and 44 deletions

View file

@ -15,15 +15,12 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartRefusal,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
@ -300,30 +297,6 @@ class OpenAIResponsesImpl:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create and yield refusal response events following the established streaming pattern."""
# Create initial response and yield created event
initial_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="in_progress",
output=[],
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Create completed refusal response using OpenAIResponseContentPartRefusal
refusal_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
@ -375,7 +348,6 @@ class OpenAIResponsesImpl:
shield_ids=shield_ids,
)
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
# Stream the response
final_response = None
failed_response = None

View file

@ -147,20 +147,6 @@ class StreamingResponseOrchestrator:
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_input_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create refusal response events for input safety violations."""
# Create the refusal content part explicitly with the correct structure
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids or not accumulated_text:
@ -237,8 +223,7 @@ class StreamingResponseOrchestrator:
input_refusal = await self._check_input_safety(self.ctx.messages)
if input_refusal:
# Return refusal response immediately
async for refusal_event in self._create_input_refusal_response_events(input_refusal):
yield refusal_event
yield await self._create_refusal_response(input_refusal.refusal)
return
async for stream_event in self._process_tools(output_messages):

View file

@ -38,6 +38,7 @@ def responses_impl_with_conversations(
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
@ -47,6 +48,7 @@ def responses_impl_with_conversations(
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)