mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 12:06:04 +00:00)
commit 171fb7101d (parent ad4362e48d)

    clean and fix tests

3 changed files with 3 additions and 44 deletions
@@ -15,15 +15,12 @@ from llama_stack.apis.agents.openai_responses import (
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIDeleteResponseObject,
-    OpenAIResponseContentPartRefusal,
     OpenAIResponseInput,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
-    OpenAIResponseObjectStreamResponseCompleted,
-    OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseText,
     OpenAIResponseTextFormat,
 )
@@ -300,30 +297,6 @@ class OpenAIResponsesImpl:
             raise ValueError("The response stream never reached a terminal state")
         return final_response

-    async def _create_refusal_response_events(
-        self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
-    ) -> AsyncIterator[OpenAIResponseObjectStream]:
-        """Create and yield refusal response events following the established streaming pattern."""
-        # Create initial response and yield created event
-        initial_response = OpenAIResponseObject(
-            id=response_id,
-            created_at=created_at,
-            model=model,
-            status="in_progress",
-            output=[],
-        )
-        yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
-
-        # Create completed refusal response using OpenAIResponseContentPartRefusal
-        refusal_response = OpenAIResponseObject(
-            id=response_id,
-            created_at=created_at,
-            model=model,
-            status="completed",
-            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
-        )
-        yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
-
     async def _create_streaming_response(
         self,
         input: str | list[OpenAIResponseInput],
@@ -375,7 +348,6 @@ class OpenAIResponsesImpl:
             shield_ids=shield_ids,
         )

-        # Output safety validation hook - delegated to streaming orchestrator for real-time validation
         # Stream the response
         final_response = None
         failed_response = None
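With the impl-level refusal generator removed, the non-streaming path in OpenAIResponsesImpl simply drains the orchestrator's event stream until a terminal event arrives. A minimal sketch of the consumption loop implied by the surrounding context lines (final_response, failed_response, and the terminal-state ValueError); the string event-type names are assumptions, not code from this commit:

# A minimal sketch, assuming the stream yields events carrying a string
# `type` field; "response.completed" / "response.failed" are assumed names.
final_response = None
failed_response = None
async for event in event_stream:  # event_stream: AsyncIterator[OpenAIResponseObjectStream]
    if event.type == "response.completed":
        final_response = event.response
    elif event.type == "response.failed":
        failed_response = event.response
if final_response is None and failed_response is None:
    raise ValueError("The response stream never reached a terminal state")
# (handling of failed_response is omitted in this sketch)
return final_response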
@@ -147,20 +147,6 @@ class StreamingResponseOrchestrator:
                 refusal=e.violation.user_message or "Content blocked by safety shields"
             )

-    async def _create_input_refusal_response_events(
-        self, refusal_content: OpenAIResponseContentPartRefusal
-    ) -> AsyncIterator[OpenAIResponseObjectStream]:
-        """Create refusal response events for input safety violations."""
-        # Create the refusal content part explicitly with the correct structure
-        refusal_response = OpenAIResponseObject(
-            id=self.response_id,
-            created_at=self.created_at,
-            model=self.ctx.model,
-            status="completed",
-            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
-        )
-        yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
-
     async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
         """Check accumulated streaming text content against shields. Returns violation message if blocked."""
         if not self.shield_ids or not accumulated_text:
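The context lines above show only the guard at the top of _check_output_stream_chunk_safety. A hedged sketch of how the rest of the check could proceed, assuming the orchestrator holds a safety_api reference and calls the Safety API's run_shield per shield (the attribute name and message construction are assumptions):

from llama_stack.apis.inference import UserMessage

async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
    """Check accumulated streaming text content against shields. Returns violation message if blocked."""
    if not self.shield_ids or not accumulated_text:
        return None
    for shield_id in self.shield_ids:
        # self.safety_api is an assumed attribute; run_shield is the Safety
        # API's shield-execution entry point.
        response = await self.safety_api.run_shield(
            shield_id=shield_id,
            messages=[UserMessage(content=accumulated_text)],
            params={},
        )
        if response.violation:
            return response.violation.user_message or "Content blocked by safety shields"
    return None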
@@ -237,8 +223,7 @@ class StreamingResponseOrchestrator:
         input_refusal = await self._check_input_safety(self.ctx.messages)
         if input_refusal:
             # Return refusal response immediately
-            async for refusal_event in self._create_input_refusal_response_events(input_refusal):
-                yield refusal_event
+            yield await self._create_refusal_response(input_refusal.refusal)
             return

         async for stream_event in self._process_tools(output_messages):
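The new call site awaits self._create_refusal_response, a helper that is not shown in this diff. A plausible sketch, assuming it collapses the removed two-event pattern into a single terminal event (the signature is inferred from the call site, which passes input_refusal.refusal):

async def _create_refusal_response(self, refusal_text: str) -> OpenAIResponseObjectStream:
    """Build a single completed-response event carrying a refusal message."""
    refusal_response = OpenAIResponseObject(
        id=self.response_id,
        created_at=self.created_at,
        model=self.ctx.model,
        status="completed",
        output=[
            OpenAIResponseMessage(
                role="assistant",
                content=[OpenAIResponseContentPartRefusal(refusal=refusal_text)],
                type="message",
            )
        ],
    )
    return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)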
@@ -38,6 +38,7 @@ def responses_impl_with_conversations(
     mock_responses_store,
     mock_vector_io_api,
     mock_conversations_api,
+    mock_safety_api,
 ):
     """Create OpenAIResponsesImpl instance with conversations API."""
     return OpenAIResponsesImpl(
@@ -47,6 +48,7 @@ def responses_impl_with_conversations(
         responses_store=mock_responses_store,
         vector_io_api=mock_vector_io_api,
         conversations_api=mock_conversations_api,
+        safety_api=mock_safety_api,
     )
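The updated fixture threads mock_safety_api through to OpenAIResponsesImpl, but that fixture's own definition is not part of this diff. A hypothetical sketch using unittest.mock, with run_shield defaulting to a no-violation result:

from unittest.mock import AsyncMock, Mock

import pytest


@pytest.fixture
def mock_safety_api():
    # Hypothetical fixture: an async mock standing in for the Safety API.
    # run_shield resolves to an object whose `violation` is None, so shield
    # checks pass by default in tests.
    safety_api = AsyncMock()
    safety_api.run_shield.return_value = Mock(violation=None)
    return safety_api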