mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
clean and fix tests
This commit is contained in:
parent
ad4362e48d
commit
171fb7101d
3 changed files with 3 additions and 44 deletions
|
|
@ -15,15 +15,12 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
ListOpenAIResponseInputItem,
|
ListOpenAIResponseInputItem,
|
||||||
ListOpenAIResponseObject,
|
ListOpenAIResponseObject,
|
||||||
OpenAIDeleteResponseObject,
|
OpenAIDeleteResponseObject,
|
||||||
OpenAIResponseContentPartRefusal,
|
|
||||||
OpenAIResponseInput,
|
OpenAIResponseInput,
|
||||||
OpenAIResponseInputMessageContentText,
|
OpenAIResponseInputMessageContentText,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseMessage,
|
OpenAIResponseMessage,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
OpenAIResponseObjectStream,
|
OpenAIResponseObjectStream,
|
||||||
OpenAIResponseObjectStreamResponseCompleted,
|
|
||||||
OpenAIResponseObjectStreamResponseCreated,
|
|
||||||
OpenAIResponseText,
|
OpenAIResponseText,
|
||||||
OpenAIResponseTextFormat,
|
OpenAIResponseTextFormat,
|
||||||
)
|
)
|
||||||
|
|
@ -300,30 +297,6 @@ class OpenAIResponsesImpl:
|
||||||
raise ValueError("The response stream never reached a terminal state")
|
raise ValueError("The response stream never reached a terminal state")
|
||||||
return final_response
|
return final_response
|
||||||
|
|
||||||
async def _create_refusal_response_events(
|
|
||||||
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
|
|
||||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
|
||||||
"""Create and yield refusal response events following the established streaming pattern."""
|
|
||||||
# Create initial response and yield created event
|
|
||||||
initial_response = OpenAIResponseObject(
|
|
||||||
id=response_id,
|
|
||||||
created_at=created_at,
|
|
||||||
model=model,
|
|
||||||
status="in_progress",
|
|
||||||
output=[],
|
|
||||||
)
|
|
||||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
|
||||||
|
|
||||||
# Create completed refusal response using OpenAIResponseContentPartRefusal
|
|
||||||
refusal_response = OpenAIResponseObject(
|
|
||||||
id=response_id,
|
|
||||||
created_at=created_at,
|
|
||||||
model=model,
|
|
||||||
status="completed",
|
|
||||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
|
||||||
)
|
|
||||||
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
|
||||||
|
|
||||||
async def _create_streaming_response(
|
async def _create_streaming_response(
|
||||||
self,
|
self,
|
||||||
input: str | list[OpenAIResponseInput],
|
input: str | list[OpenAIResponseInput],
|
||||||
|
|
@ -375,7 +348,6 @@ class OpenAIResponsesImpl:
|
||||||
shield_ids=shield_ids,
|
shield_ids=shield_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
final_response = None
|
final_response = None
|
||||||
failed_response = None
|
failed_response = None
|
||||||
|
|
|
||||||
|
|
@ -147,20 +147,6 @@ class StreamingResponseOrchestrator:
|
||||||
refusal=e.violation.user_message or "Content blocked by safety shields"
|
refusal=e.violation.user_message or "Content blocked by safety shields"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_input_refusal_response_events(
|
|
||||||
self, refusal_content: OpenAIResponseContentPartRefusal
|
|
||||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
|
||||||
"""Create refusal response events for input safety violations."""
|
|
||||||
# Create the refusal content part explicitly with the correct structure
|
|
||||||
refusal_response = OpenAIResponseObject(
|
|
||||||
id=self.response_id,
|
|
||||||
created_at=self.created_at,
|
|
||||||
model=self.ctx.model,
|
|
||||||
status="completed",
|
|
||||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
|
||||||
)
|
|
||||||
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
|
||||||
|
|
||||||
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
|
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
|
||||||
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
|
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
|
||||||
if not self.shield_ids or not accumulated_text:
|
if not self.shield_ids or not accumulated_text:
|
||||||
|
|
@ -237,8 +223,7 @@ class StreamingResponseOrchestrator:
|
||||||
input_refusal = await self._check_input_safety(self.ctx.messages)
|
input_refusal = await self._check_input_safety(self.ctx.messages)
|
||||||
if input_refusal:
|
if input_refusal:
|
||||||
# Return refusal response immediately
|
# Return refusal response immediately
|
||||||
async for refusal_event in self._create_input_refusal_response_events(input_refusal):
|
yield await self._create_refusal_response(input_refusal.refusal)
|
||||||
yield refusal_event
|
|
||||||
return
|
return
|
||||||
|
|
||||||
async for stream_event in self._process_tools(output_messages):
|
async for stream_event in self._process_tools(output_messages):
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ def responses_impl_with_conversations(
|
||||||
mock_responses_store,
|
mock_responses_store,
|
||||||
mock_vector_io_api,
|
mock_vector_io_api,
|
||||||
mock_conversations_api,
|
mock_conversations_api,
|
||||||
|
mock_safety_api,
|
||||||
):
|
):
|
||||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||||
return OpenAIResponsesImpl(
|
return OpenAIResponsesImpl(
|
||||||
|
|
@ -47,6 +48,7 @@ def responses_impl_with_conversations(
|
||||||
responses_store=mock_responses_store,
|
responses_store=mock_responses_store,
|
||||||
vector_io_api=mock_vector_io_api,
|
vector_io_api=mock_vector_io_api,
|
||||||
conversations_api=mock_conversations_api,
|
conversations_api=mock_conversations_api,
|
||||||
|
safety_api=mock_safety_api,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue