refactor: unify stream and non-stream impls for responses (#2388)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, datasets) (push) Failing after 9s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, providers) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 11s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 7s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 30s
Pre-commit / pre-commit (push) Successful in 1m18s

The non-streaming version is now just a thin layer on top of the streaming version: consume the stream, pluck off the final `response.completed` event, and return that as the response!
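To make that concrete, here is a minimal sketch of the layering (not the exact implementation; `response_from_stream` is a hypothetical helper name, and it assumes a stream of event objects carrying a `type` field and, for the completion event, a `response` payload, mirroring the names used in the diff below):

```python
from collections.abc import AsyncIterator
from typing import Any


async def response_from_stream(stream: AsyncIterator[Any]) -> Any:
    """Drain the event stream and return the payload of the final
    `response.completed` event; fail loudly if it never arrives."""
    final = None
    async for event in stream:
        if event.type == "response.completed":
            if final is not None:
                raise ValueError("The response stream completed multiple times")
            final = event.response
    if final is None:
        raise ValueError("The response stream never completed")
    return final
```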

This PR also includes a couple of other changes that I ended up making while working on it on a flight:
- changes to `ollama` so it does not pull embedding models unconditionally (a sketch of the check follows this list)
- a small fix to the library client to make the stream and non-stream cases a bit more symmetric
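For the `ollama` change, the gist is to consult the server's model list before pulling. A rough sketch assuming the `ollama` Python `AsyncClient`; the `pull_embedding_model_if_missing` helper is a hypothetical stand-in for the adapter's model-registration path shown in the diff below:

```python
from ollama import AsyncClient


async def pull_embedding_model_if_missing(client: AsyncClient, model_id: str) -> None:
    # Only pull if the Ollama server does not already have the model.
    response = await client.list()
    if model_id not in [m.model for m in response.models]:
        await client.pull(model_id)
```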
Ashwin Bharambe 2025-06-05 17:48:09 +02:00 committed by GitHub
parent ef885d2147
commit 3251b44d8a
4 changed files with 166 additions and 315 deletions


@@ -149,12 +149,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs):
# NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if kwargs.get("stream"):
# NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
def sync_generator():
try:
@@ -172,7 +173,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return sync_generator()
else:
return asyncio.run(self.async_client.request(*args, **kwargs))
try:
result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return result
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):


@@ -8,7 +8,7 @@ import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@@ -200,7 +200,6 @@ class ChatCompletionContext(BaseModel):
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
response_format: OpenAIResponseFormatParam
@@ -281,49 +280,6 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
def _is_function_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if self._is_function_tool_call(choice.message.tool_calls[0], tools):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
else:
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
@@ -370,9 +326,48 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = False if stream is None else stream
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
@@ -383,7 +378,7 @@ class OpenAIResponsesImpl:
# Structured outputs
response_format = await _convert_response_text_to_chat_response_format(text)
# Tool setup
# Tool setup, TODO: refactor this slightly since this can also yield events
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
@@ -395,136 +390,10 @@ class OpenAIResponsesImpl:
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
response_format=response_format,
)
# Fork to streaming vs non-streaming - let each handle ALL inference rounds
if stream:
return self._create_streaming_response(
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
else:
return await self._create_non_streaming_response(
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
async def _create_non_streaming_response(
self,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
text: OpenAIResponseText,
tools: list[OpenAIResponseInputTool] | None,
max_infer_iters: int,
) -> OpenAIResponseObject:
n_iter = 0
messages = ctx.messages.copy()
while True:
# Do inference (including the first one)
inference_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
stream=False,
temperature=ctx.temperature,
response_format=ctx.response_format,
)
completion = OpenAIChatCompletion(**inference_result.model_dump())
# Separate function vs non-function tool calls
function_tool_calls = []
non_function_tool_calls = []
for choice in completion.choices:
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if self._is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
# Process response choices based on tool call types
if function_tool_calls:
# For function tool calls, use existing logic and return immediately
current_output_messages = await self._process_response_choices(
chat_response=completion,
ctx=ctx,
tools=tools,
)
output_messages.extend(current_output_messages)
break
elif non_function_tool_calls:
# For non-function tool calls, execute them and continue loop
for choice in completion.choices:
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
output_messages.extend(tool_outputs)
# Add assistant message and tool responses to messages for next iteration
messages.append(choice.message)
messages.extend(tool_response_messages)
n_iter += 1
if n_iter >= max_infer_iters:
break
# Continue with next iteration of the loop
continue
else:
# No tool calls - convert response to message and we're done
for choice in completion.choices:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
break
response = OpenAIResponseObject(
created_at=completion.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
text=text,
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
await self._store_response(
response=response,
input=input,
)
return response
async def _create_streaming_response(
self,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
text: OpenAIResponseText,
tools: list[OpenAIResponseInputTool] | None,
max_infer_iters: int | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
@@ -539,15 +408,13 @@ class OpenAIResponsesImpl:
text=text,
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Implement tool execution loop for streaming - handle ALL inference rounds including the first
n_iter = 0
messages = ctx.messages.copy()
while True:
current_inference_result = await self.inference_api.openai_chat_completion(
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
@@ -568,7 +435,7 @@ class OpenAIResponsesImpl:
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in current_inference_result:
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
@@ -628,50 +495,55 @@ class OpenAIResponsesImpl:
model=chunk_model,
)
# Separate function vs non-function tool calls
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if self._is_function_tool_call(tool_call, tools):
if _is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
else:
# Process response choices based on tool call types
if function_tool_calls:
# For function tool calls, use existing logic and break
current_output_messages = await self._process_response_choices(
chat_response=current_response,
ctx=ctx,
tools=tools,
)
output_messages.extend(current_output_messages)
break
elif non_function_tool_calls:
# For non-function tool calls, execute them and continue loop
for choice in current_response.choices:
tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
output_messages.extend(tool_outputs)
# Add assistant message and tool responses to messages for next iteration
messages.append(choice.message)
messages.extend(tool_response_messages)
n_iter += 1
if n_iter >= (max_infer_iters or 10):
break
# Continue with next iteration of the loop
continue
else:
# No tool calls - convert response to message and we're done
for choice in current_response.choices:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
@@ -683,15 +555,15 @@ class OpenAIResponsesImpl:
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@@ -784,73 +656,6 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_calls_only(
self,
choice: OpenAIChoice,
ctx: ChatCompletionContext,
) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
"""Execute tool calls and return output messages and tool response messages for next inference."""
output_messages: list[OpenAIResponseOutput] = []
tool_response_messages: list[OpenAIMessageParam] = []
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages, tool_response_messages
if not choice.message.tool_calls:
return output_messages, tool_response_messages
for tool_call in choice.message.tool_calls:
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
tool_response_messages.append(further_input)
return output_messages, tool_response_messages
async def _execute_tool_and_return_final_output(
self,
choice: OpenAIChoice,
ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
if not choice.message.tool_calls:
return output_messages
next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
# TODO: telemetry spans for tool calls
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=next_turn_messages,
stream=ctx.stream,
temperature=ctx.temperature,
)
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@@ -939,3 +744,15 @@ class OpenAIResponsesImpl:
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
def _is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False


@@ -345,21 +345,27 @@ class OllamaInferenceAdapter(
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
# TODO: you should pull here only if the model is not found in a list
response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
available_models = [m.model for m in response.models]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
if provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"


@@ -80,6 +80,37 @@ def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_ru
)
async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
value = load_chat_completion_fixture(fixture)
yield ChatCompletionChunk(
id=value.id,
choices=[
Choice(
index=0,
delta=ChoiceDelta(
content=c.message.content,
role=c.message.role,
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id=t.id,
function=ChoiceDeltaToolCallFunction(
name=t.function.name,
arguments=t.function.arguments,
),
)
for t in (c.message.tool_calls or [])
],
),
)
for c in value.choices
],
created=1,
model=value.model,
object="chat.completion.chunk",
)
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input."""
@@ -88,8 +119,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
model = "meta-llama/Llama-3.1-8B-Instruct"
# Load the chat completion fixture
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute # Execute
result = await openai_responses_impl.create_openai_response( result = await openai_responses_impl.create_openai_response(
@@ -104,7 +134,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=OpenAIResponseFormatText(),
tools=None,
stream=False,
stream=True,
temperature=0.1,
)
openai_responses_impl.responses_store.store_response_object.assert_called_once()
@@ -121,20 +151,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
# Load the chat completion fixtures
tool_call_completion = load_chat_completion_fixture("tool_call_completion.yaml")
tool_response_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.side_effect = [
tool_call_completion,
tool_response_completion,
fake_stream("tool_call_completion.yaml"),
fake_stream(),
]
openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
identifier="web_search",
provider_id="client",
toolgroup_id="web_search",
tool_host="client",
description="Search the web for information",
parameters=[
ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
@@ -189,7 +214,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
input_text = "How hot it is in San Francisco today?"
model = "meta-llama/Llama-3.1-8B-Instruct"
async def fake_stream():
async def fake_stream_toolcall():
yield ChatCompletionChunk(
id="123",
choices=[
@@ -212,7 +237,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
object="chat.completion.chunk",
)
mock_inference_api.openai_chat_completion.return_value = fake_stream()
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
# Execute
result = await openai_responses_impl.create_openai_response(
@@ -271,7 +296,7 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
]
model = "meta-llama/Llama-3.1-8B-Instruct"
mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
@@ -399,9 +424,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
# Load the chat completion fixture
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
@@ -440,8 +463,7 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
@@ -499,8 +521,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
@@ -674,8 +696,8 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
current_input = "Now what is 3+3?"
model = "meta-llama/Llama-3.1-8B-Instruct"
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute - Create response with previous_response_id
result = await openai_responses_impl.create_openai_response(
@@ -732,9 +754,7 @@ async def test_create_openai_response_with_text_format(
input_text = "How hot it is in San Francisco today?"
model = "meta-llama/Llama-3.1-8B-Instruct"
# Load the chat completion fixture
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
_result = await openai_responses_impl.create_openai_response(