Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
feat(responses): improve streaming for function calls
commit 674478c851 (parent 6358d0a478)
2 changed files with 236 additions and 28 deletions
@@ -33,6 +33,10 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseOutputItemAdded,
+    OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
@@ -73,7 +77,9 @@ from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
-from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_tooldef_to_openai_tool,
+)
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore

 logger = get_logger(name=__name__, category="openai_responses")
@@ -82,7 +88,7 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"


 async def _convert_response_content_to_chat_content(
-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
+    content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
     """
     Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@@ -150,7 +156,9 @@ async def _convert_response_input_to_chat_messages(
     return messages


-async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def _convert_chat_choice_to_response_message(
+    choice: OpenAIChoice,
+) -> OpenAIResponseMessage:
     """
     Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
     """
@@ -172,7 +180,9 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
     )


-async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
+async def _convert_response_text_to_chat_response_format(
+    text: OpenAIResponseText,
+) -> OpenAIResponseFormatParam:
     """
     Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
     """
@@ -228,7 +238,9 @@ class OpenAIResponsesImpl:
         self.vector_io_api = vector_io_api

     async def _prepend_previous_response(
-        self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None = None,
     ):
         if previous_response_id:
             previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
@@ -446,6 +458,8 @@ class OpenAIResponsesImpl:

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
+        # Track tool call items for streaming events
+        tool_call_item_ids: dict[int, str] = {}

         async for chunk in completion_result:
             chat_response_id = chunk.id
@@ -472,18 +486,61 @@
                 if chunk_choice.delta.tool_calls:
                     for tool_call in chunk_choice.delta.tool_calls:
                         response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
-                        if response_tool_call:
-                            # Don't attempt to concatenate arguments if we don't have any new arguments
-                            if tool_call.function.arguments:
-                                # Guard against an initial None argument before we concatenate
-                                response_tool_call.function.arguments = (
-                                    response_tool_call.function.arguments or ""
-                                ) + tool_call.function.arguments
-                        else:
+                        # Create new tool call entry if this is the first chunk for this index
+                        if response_tool_call is None:
                             tool_call_dict: dict[str, Any] = tool_call.model_dump()
                             tool_call_dict.pop("type", None)
                             response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                             chat_response_tool_calls[tool_call.index] = response_tool_call

+                            # Create item ID for this tool call for streaming events
+                            tool_call_item_id = f"fc_{uuid.uuid4()}"
+                            tool_call_item_ids[tool_call.index] = tool_call_item_id
+
+                            # Emit output_item.added event for the new function call
+                            sequence_number += 1
+                            function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                                arguments="",  # Will be filled incrementally via delta events
+                                call_id=tool_call.id or "",
+                                name=tool_call.function.name if tool_call.function else "",
+                                id=tool_call_item_id,
+                                status="in_progress",
+                            )
+                            yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+                                response_id=response_id,
+                                item=function_call_item,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                        # Stream function call arguments as they arrive
+                        if tool_call.function and tool_call.function.arguments:
+                            tool_call_item_id = tool_call_item_ids[tool_call.index]
+                            sequence_number += 1
+                            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
+                                delta=tool_call.function.arguments,
+                                item_id=tool_call_item_id,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                            # Accumulate arguments for final response
+                            response_tool_call.function.arguments = (
+                                response_tool_call.function.arguments or ""
+                            ) + tool_call.function.arguments
+
+        # Emit function_call_arguments.done events for completed tool calls
+        for tool_call_index in sorted(chat_response_tool_calls.keys()):
+            if tool_call_index in tool_call_item_ids:
+                tool_call_item_id = tool_call_item_ids[tool_call_index]
+                final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
+                sequence_number += 1
+                yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
+                    arguments=final_arguments,
+                    item_id=tool_call_item_id,
+                    output_index=len(output_messages),
+                    sequence_number=sequence_number,
+                )
+
         # Convert collected chunks to complete response
         if chat_response_tool_calls:
@@ -532,18 +589,56 @@
                 tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
                 if tool_call_log:
                     output_messages.append(tool_call_log)
+
+                    # Emit output_item.done event for completed non-function tool call
+                    # Find the item_id for this tool call
+                    matching_item_id = None
+                    for index, item_id in tool_call_item_ids.items():
+                        response_tool_call = chat_response_tool_calls.get(index)
+                        if response_tool_call and response_tool_call.id == tool_call.id:
+                            matching_item_id = item_id
+                            break
+
+                    if matching_item_id:
+                        sequence_number += 1
+                        yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                            response_id=response_id,
+                            item=tool_call_log,
+                            output_index=len(output_messages) - 1,
+                            sequence_number=sequence_number,
+                        )
+
                 if tool_response_message:
                     next_turn_messages.append(tool_response_message)

             for tool_call in function_tool_calls:
-                output_messages.append(
-                    OpenAIResponseOutputMessageFunctionToolCall(
-                        arguments=tool_call.function.arguments or "",
-                        call_id=tool_call.id,
-                        name=tool_call.function.name or "",
-                        id=f"fc_{uuid.uuid4()}",
-                        status="completed",
-                    )
+                # Find the item_id for this tool call from our tracking dictionary
+                matching_item_id = None
+                for index, item_id in tool_call_item_ids.items():
+                    response_tool_call = chat_response_tool_calls.get(index)
+                    if response_tool_call and response_tool_call.id == tool_call.id:
+                        matching_item_id = item_id
+                        break
+
+                # Use existing item_id or create new one if not found
+                final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
+
+                function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                    arguments=tool_call.function.arguments or "",
+                    call_id=tool_call.id,
+                    name=tool_call.function.name or "",
+                    id=final_item_id,
+                    status="completed",
+                )
+                output_messages.append(function_call_item)
+
+                # Emit output_item.done event for completed function call
+                sequence_number += 1
+                yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                    response_id=response_id,
+                    item=function_call_item,
+                    output_index=len(output_messages) - 1,
+                    sequence_number=sequence_number,
                 )

             if not function_tool_calls and not non_function_tool_calls:
@@ -779,7 +874,8 @@
             )
         elif function.name == "knowledge_search":
             response_file_search_tool = next(
-                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
+                None,
             )
             if response_file_search_tool:
                 # Use vector_stores.search API instead of knowledge_search tool
@@ -798,7 +894,9 @@
             error_exc = e

         if function.name in ctx.mcp_tool_to_server:
-            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+            from llama_stack.apis.agents.openai_responses import (
+                OpenAIResponseOutputMessageMCPCall,
+            )

             message = OpenAIResponseOutputMessageMCPCall(
                 id=tool_call_id,
@@ -850,7 +948,10 @@
         if isinstance(result.content, str):
             content = result.content
         elif isinstance(result.content, list):
-            from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+            from llama_stack.apis.common.content_types import (
+                ImageContentItem,
+                TextContentItem,
+            )

             content = []
             for item in result.content:
@@ -384,12 +384,18 @@ def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id,
     assert list_tools.type == "mcp_list_tools"
     assert list_tools.server_label == "localmcp"
     assert len(list_tools.tools) == 2
-    assert {t.name for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
+    assert {t.name for t in list_tools.tools} == {
+        "get_boiling_point",
+        "greet_everyone",
+    }

     call = response.output[1]
     assert call.type == "mcp_call"
     assert call.name == "get_boiling_point"
-    assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
+    assert json.loads(call.arguments) == {
+        "liquid_name": "myawesomeliquid",
+        "celsius": True,
+    }
     assert call.error is None
     assert "-100" in call.output

@@ -581,6 +587,105 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_
         f"Last chunk should be response.completed, got {chunks[-1].type}"
     )

+    # Verify tool call streaming events are present
+    chunk_types = [chunk.type for chunk in chunks]
+
+    # Should have function call arguments delta events for tool calls
+    delta_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.delta"]
+    done_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.done"]
+
+    # Should have output item events for tool calls
+    item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"]
+    item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"]
+
+    # Verify we have substantial streaming activity (not just batch events)
+    assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks"
+
+    # Since this test involves MCP tool calls, we should see streaming events
+    assert len(delta_events) > 0, f"Expected function_call_arguments.delta events, got chunk types: {chunk_types}"
+    assert len(done_events) > 0, f"Expected function_call_arguments.done events, got chunk types: {chunk_types}"
+
+    # Should have output item events for function calls
+    assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}"
+    assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}"
+
+    # Verify delta events have proper structure
+    for delta_event in delta_events:
+        assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field"
+        assert hasattr(delta_event, "item_id"), "Delta event should have 'item_id' field"
+        assert hasattr(delta_event, "sequence_number"), "Delta event should have 'sequence_number' field"
+        assert delta_event.delta, "Delta should not be empty"
+
+    # Verify done events have proper structure
+    for done_event in done_events:
+        assert hasattr(done_event, "arguments"), "Done event should have 'arguments' field"
+        assert hasattr(done_event, "item_id"), "Done event should have 'item_id' field"
+        assert done_event.arguments, "Final arguments should not be empty"
+
+    # Verify output item added events have proper structure
+    for added_event in item_added_events:
+        assert hasattr(added_event, "item"), "Added event should have 'item' field"
+        assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field"
+        assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field"
+        assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field"
+        assert added_event.item.type in ["function_call", "mcp_call"], "Added item should be a tool call"
+        assert added_event.item.status == "in_progress", "Added item should be in progress"
+        assert added_event.response_id, "Response ID should not be empty"
+        assert isinstance(added_event.output_index, int), "Output index should be integer"
+        assert added_event.output_index >= 0, "Output index should be non-negative"
+
+    # Verify output item done events have proper structure
+    for done_event in item_done_events:
+        assert hasattr(done_event, "item"), "Done event should have 'item' field"
+        assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field"
+        assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field"
+        assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field"
+        assert done_event.item.type in ["function_call", "mcp_call"], "Done item should be a tool call"
+        # Note: MCP calls don't have a status field, only function calls do
+        if done_event.item.type == "function_call":
+            assert done_event.item.status == "completed", "Function call should be completed"
+        assert done_event.response_id, "Response ID should not be empty"
+        assert isinstance(done_event.output_index, int), "Output index should be integer"
+        assert done_event.output_index >= 0, "Output index should be non-negative"
+
+    # Group function call argument events by item_id (these should have proper tracking)
+    function_call_events_by_item_id = {}
+    for chunk in chunks:
+        if hasattr(chunk, "item_id") and chunk.type in [
+            "response.function_call_arguments.delta",
+            "response.function_call_arguments.done",
+        ]:
+            item_id = chunk.item_id
+            if item_id not in function_call_events_by_item_id:
+                function_call_events_by_item_id[item_id] = []
+            function_call_events_by_item_id[item_id].append(chunk)
+
+    for item_id, related_events in function_call_events_by_item_id.items():
+        # Should have at least one delta and one done event for a complete function call
+        delta_events = [e for e in related_events if e.type == "response.function_call_arguments.delta"]
+        done_events = [e for e in related_events if e.type == "response.function_call_arguments.done"]
+
+        assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event"
+        assert len(done_events) == 1, f"Item {item_id} should have exactly one done event"
+
+        # Verify all events have the same item_id
+        for event in related_events:
+            assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}"
+
+    # Basic pairing check: each output_item.added should be followed by some activity
+    # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios)
+    assert len(item_added_events) > 0, "Should have at least one output_item.added event"
+
+    # Verify response_id consistency across all events
+    response_ids = set()
+    for chunk in chunks:
+        if hasattr(chunk, "response_id"):
+            response_ids.add(chunk.response_id)
+        elif hasattr(chunk, "response") and hasattr(chunk.response, "id"):
+            response_ids.add(chunk.response.id)
+
+    assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}"
+
     # Get the final response from the last chunk
     final_chunk = chunks[-1]
     if hasattr(final_chunk, "response"):
@@ -722,7 +827,9 @@ def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_fact

     # Attach file to vector store with attributes
     file_attach_response = compat_client.vector_stores.files.create(
-        vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"]
+        vector_store_id=vector_store.id,
+        file_id=file_response.id,
+        attributes=file_data["attributes"],
     )

     # Wait for attachment