fix tests

This commit is contained in:
Swapna Lekkala 2025-10-10 09:38:33 -07:00
parent b5c951fa4b
commit 7b5dd54b7a
4 changed files with 13 additions and 8 deletions

View file

@ -345,7 +345,7 @@ class OpenAIResponsesImpl:
return return
# Structured outputs # Structured outputs
response_format = convert_response_text_to_chat_response_format(text) response_format = await convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext( ctx = ChatCompletionContext(
model=model, model=model,

View file

@ -13,8 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter, ApprovalFilter,
MCPListToolsTool, MCPListToolsTool,
OpenAIResponseContentPartOutputText, OpenAIResponseContentPartOutputText,
OpenAIResponseError,
OpenAIResponseContentPartRefusal, OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseInputToolMCP, OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest, OpenAIResponseMCPApprovalRequest,
@ -226,6 +226,11 @@ class StreamingResponseOrchestrator:
completion_result_data = stream_event_or_result completion_result_data = stream_event_or_result
else: else:
yield stream_event_or_result yield stream_event_or_result
# If violation detected, skip the rest of processing since we already sent refusal
if self.violation_detected:
return
if not completion_result_data: if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data") raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data last_completion_result = completion_result_data

View file

@ -178,7 +178,7 @@ async def convert_response_input_to_chat_messages(
pass pass
else: else:
content = await convert_response_content_to_chat_content(input_item.content) content = await convert_response_content_to_chat_content(input_item.content)
message_type = get_message_type_by_role(input_item.role) message_type = await get_message_type_by_role(input_item.role)
if message_type is None: if message_type is None:
raise ValueError( raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"

View file

@ -84,14 +84,14 @@ def test_extract_shield_ids_empty_list(responses_impl):
assert result == [] assert result == []
def test_extract_shield_ids_unknown_format(responses_impl, caplog): def test_extract_shield_ids_unknown_format(responses_impl):
"""Test extraction with unknown shield format logs warning.""" """Test extraction with unknown shield format raises ValueError."""
# Create an object that's neither string nor ResponseShieldSpec # Create an object that's neither string nor ResponseShieldSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
shields = ["valid-shield", unknown_object, "another-shield"] shields = ["valid-shield", unknown_object, "another-shield"]
result = extract_shield_ids(shields)
assert result == ["valid-shield", "another-shield"] with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
assert "Unknown shield format" in caplog.text extract_shield_ids(shields)
def test_extract_text_content_string(responses_impl): def test_extract_text_content_string(responses_impl):