Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
refactor: unify stream and non-stream impls for responses (#2388)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, datasets) (push) Failing after 9s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, providers) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 11s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 7s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 30s
Pre-commit / pre-commit (push) Successful in 1m18s
The non-streaming version is just a small layer on top of the streaming version: just pluck off the final `response.completed` event and return that as the response!

This PR also includes a couple of other changes which I ended up making while working on it on a flight:
- changes to `ollama` so it does not pull embedding models unconditionally
- a small fix to the library client to make the stream and non-stream cases a bit more symmetric
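For illustration, a minimal sketch of that pattern, assuming only a stream of event objects with a `type` field and a `response` payload as in the diff below (the helper name `collect_final_response` is hypothetical; the real change lives in `create_openai_response`):

```python
# Sketch only: drain the stream and keep the payload of the final
# "response.completed" event as the non-streaming result.
async def collect_final_response(stream_gen):
    response = None
    async for chunk in stream_gen:
        if chunk.type == "response.completed":
            response = chunk.response
        # keep iterating so the generator is not left half complete
    if response is None:
        raise ValueError("The response stream never completed")
    return response
```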
Parent: ef885d2147
Commit: 3251b44d8a
4 changed files with 166 additions and 315 deletions
@@ -149,12 +149,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

     def request(self, *args, **kwargs):
+        # NOTE: We are using AsyncLlamaStackClient under the hood
+        # A new event loop is needed to convert the AsyncStream
+        # from async client into SyncStream return type for streaming
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
         if kwargs.get("stream"):
-            # NOTE: We are using AsyncLlamaStackClient under the hood
-            # A new event loop is needed to convert the AsyncStream
-            # from async client into SyncStream return type for streaming
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)

             def sync_generator():
                 try:
@@ -172,7 +173,14 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

             return sync_generator()
         else:
-            return asyncio.run(self.async_client.request(*args, **kwargs))
+            try:
+                result = loop.run_until_complete(self.async_client.request(*args, **kwargs))
+            finally:
+                pending = asyncio.all_tasks(loop)
+                if pending:
+                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+                loop.close()
+            return result


 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -8,7 +8,7 @@ import json
 import time
 import uuid
 from collections.abc import AsyncIterator
-from typing import Any, cast
+from typing import Any

 from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
@@ -200,7 +200,6 @@ class ChatCompletionContext(BaseModel):
     messages: list[OpenAIMessageParam]
     tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
-    stream: bool
     temperature: float | None
     response_format: OpenAIResponseFormatParam

@@ -281,49 +280,6 @@ class OpenAIResponsesImpl:
         """
         return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)

-    def _is_function_tool_call(
-        self,
-        tool_call: OpenAIChatCompletionToolCall,
-        tools: list[OpenAIResponseInputTool],
-    ) -> bool:
-        if not tool_call.function:
-            return False
-        for t in tools:
-            if t.type == "function" and t.name == tool_call.function.name:
-                return True
-        return False
-
-    async def _process_response_choices(
-        self,
-        chat_response: OpenAIChatCompletion,
-        ctx: ChatCompletionContext,
-        tools: list[OpenAIResponseInputTool] | None,
-    ) -> list[OpenAIResponseOutput]:
-        """Handle tool execution and response message creation."""
-        output_messages: list[OpenAIResponseOutput] = []
-        # Execute tool calls if any
-        for choice in chat_response.choices:
-            if choice.message.tool_calls and tools:
-                # Assume if the first tool is a function, all tools are functions
-                if self._is_function_tool_call(choice.message.tool_calls[0], tools):
-                    for tool_call in choice.message.tool_calls:
-                        output_messages.append(
-                            OpenAIResponseOutputMessageFunctionToolCall(
-                                arguments=tool_call.function.arguments or "",
-                                call_id=tool_call.id,
-                                name=tool_call.function.name or "",
-                                id=f"fc_{uuid.uuid4()}",
-                                status="completed",
-                            )
-                        )
-                else:
-                    tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
-                    output_messages.extend(tool_messages)
-            else:
-                output_messages.append(await _convert_chat_choice_to_response_message(choice))
-
-        return output_messages
-
     async def _store_response(
         self,
         response: OpenAIResponseObject,
@@ -370,9 +326,48 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
     ):
-        stream = False if stream is None else stream
+        stream = bool(stream)
         text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

+        stream_gen = self._create_streaming_response(
+            input=input,
+            model=model,
+            instructions=instructions,
+            previous_response_id=previous_response_id,
+            store=store,
+            temperature=temperature,
+            text=text,
+            tools=tools,
+            max_infer_iters=max_infer_iters,
+        )
+
+        if stream:
+            return stream_gen
+        else:
+            response = None
+            async for stream_chunk in stream_gen:
+                if stream_chunk.type == "response.completed":
+                    if response is not None:
+                        raise ValueError("The response stream completed multiple times! Earlier response: {response}")
+                    response = stream_chunk.response
+                # don't leave the generator half complete!
+
+            if response is None:
+                raise ValueError("The response stream never completed")
+            return response
+
+    async def _create_streaming_response(
+        self,
+        input: str | list[OpenAIResponseInput],
+        model: str,
+        instructions: str | None = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+        max_infer_iters: int | None = 10,
+    ) -> AsyncIterator[OpenAIResponseObjectStream]:
         output_messages: list[OpenAIResponseOutput] = []

         # Input preprocessing
@@ -383,7 +378,7 @@ class OpenAIResponsesImpl:
         # Structured outputs
         response_format = await _convert_response_text_to_chat_response_format(text)

-        # Tool setup
+        # Tool setup, TODO: refactor this slightly since this can also yield events
         chat_tools, mcp_tool_to_server, mcp_list_message = (
             await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
         )
@@ -395,136 +390,10 @@ class OpenAIResponsesImpl:
             messages=messages,
             tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
-            stream=stream,
             temperature=temperature,
             response_format=response_format,
         )

-        # Fork to streaming vs non-streaming - let each handle ALL inference rounds
-        if stream:
-            return self._create_streaming_response(
-                ctx=ctx,
-                output_messages=output_messages,
-                input=input,
-                model=model,
-                store=store,
-                text=text,
-                tools=tools,
-                max_infer_iters=max_infer_iters,
-            )
-        else:
-            return await self._create_non_streaming_response(
-                ctx=ctx,
-                output_messages=output_messages,
-                input=input,
-                model=model,
-                store=store,
-                text=text,
-                tools=tools,
-                max_infer_iters=max_infer_iters,
-            )
-
-    async def _create_non_streaming_response(
-        self,
-        ctx: ChatCompletionContext,
-        output_messages: list[OpenAIResponseOutput],
-        input: str | list[OpenAIResponseInput],
-        model: str,
-        store: bool | None,
-        text: OpenAIResponseText,
-        tools: list[OpenAIResponseInputTool] | None,
-        max_infer_iters: int,
-    ) -> OpenAIResponseObject:
-        n_iter = 0
-        messages = ctx.messages.copy()
-
-        while True:
-            # Do inference (including the first one)
-            inference_result = await self.inference_api.openai_chat_completion(
-                model=ctx.model,
-                messages=messages,
-                tools=ctx.tools,
-                stream=False,
-                temperature=ctx.temperature,
-                response_format=ctx.response_format,
-            )
-            completion = OpenAIChatCompletion(**inference_result.model_dump())
-
-            # Separate function vs non-function tool calls
-            function_tool_calls = []
-            non_function_tool_calls = []
-
-            for choice in completion.choices:
-                if choice.message.tool_calls and tools:
-                    for tool_call in choice.message.tool_calls:
-                        if self._is_function_tool_call(tool_call, tools):
-                            function_tool_calls.append(tool_call)
-                        else:
-                            non_function_tool_calls.append(tool_call)
-
-            # Process response choices based on tool call types
-            if function_tool_calls:
-                # For function tool calls, use existing logic and return immediately
-                current_output_messages = await self._process_response_choices(
-                    chat_response=completion,
-                    ctx=ctx,
-                    tools=tools,
-                )
-                output_messages.extend(current_output_messages)
-                break
-            elif non_function_tool_calls:
-                # For non-function tool calls, execute them and continue loop
-                for choice in completion.choices:
-                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
-                    output_messages.extend(tool_outputs)
-
-                    # Add assistant message and tool responses to messages for next iteration
-                    messages.append(choice.message)
-                    messages.extend(tool_response_messages)
-
-                n_iter += 1
-                if n_iter >= max_infer_iters:
-                    break
-
-                # Continue with next iteration of the loop
-                continue
-            else:
-                # No tool calls - convert response to message and we're done
-                for choice in completion.choices:
-                    output_messages.append(await _convert_chat_choice_to_response_message(choice))
-                break
-
-        response = OpenAIResponseObject(
-            created_at=completion.created,
-            id=f"resp-{uuid.uuid4()}",
-            model=model,
-            object="response",
-            status="completed",
-            output=output_messages,
-            text=text,
-        )
-        logger.debug(f"OpenAI Responses response: {response}")
-
-        # Store response if requested
-        if store:
-            await self._store_response(
-                response=response,
-                input=input,
-            )
-
-        return response
-
-    async def _create_streaming_response(
-        self,
-        ctx: ChatCompletionContext,
-        output_messages: list[OpenAIResponseOutput],
-        input: str | list[OpenAIResponseInput],
-        model: str,
-        store: bool | None,
-        text: OpenAIResponseText,
-        tools: list[OpenAIResponseInputTool] | None,
-        max_infer_iters: int | None,
-    ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Create initial response and emit response.created immediately
         response_id = f"resp-{uuid.uuid4()}"
         created_at = int(time.time())
@@ -539,15 +408,13 @@ class OpenAIResponsesImpl:
             text=text,
         )

-        # Emit response.created immediately
         yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)

-        # Implement tool execution loop for streaming - handle ALL inference rounds including the first
         n_iter = 0
         messages = ctx.messages.copy()

         while True:
-            current_inference_result = await self.inference_api.openai_chat_completion(
+            completion_result = await self.inference_api.openai_chat_completion(
                 model=ctx.model,
                 messages=messages,
                 tools=ctx.tools,
@@ -568,7 +435,7 @@ class OpenAIResponsesImpl:
             # Create a placeholder message item for delta events
             message_item_id = f"msg_{uuid.uuid4()}"

-            async for chunk in current_inference_result:
+            async for chunk in completion_result:
                 chat_response_id = chunk.id
                 chunk_created = chunk.created
                 chunk_model = chunk.model
@@ -628,50 +495,55 @@ class OpenAIResponsesImpl:
                 model=chunk_model,
             )

-            # Separate function vs non-function tool calls
             function_tool_calls = []
             non_function_tool_calls = []

+            next_turn_messages = messages.copy()
             for choice in current_response.choices:
+                next_turn_messages.append(choice.message)
+
                 if choice.message.tool_calls and tools:
                     for tool_call in choice.message.tool_calls:
-                        if self._is_function_tool_call(tool_call, tools):
+                        if _is_function_tool_call(tool_call, tools):
                             function_tool_calls.append(tool_call)
                         else:
                             non_function_tool_calls.append(tool_call)
+                else:
-            # Process response choices based on tool call types
-            if function_tool_calls:
-                # For function tool calls, use existing logic and break
-                current_output_messages = await self._process_response_choices(
-                    chat_response=current_response,
-                    ctx=ctx,
-                    tools=tools,
-                )
-                output_messages.extend(current_output_messages)
-                break
-            elif non_function_tool_calls:
-                # For non-function tool calls, execute them and continue loop
-                for choice in current_response.choices:
-                    tool_outputs, tool_response_messages = await self._execute_tool_calls_only(choice, ctx)
-                    output_messages.extend(tool_outputs)
-
-                    # Add assistant message and tool responses to messages for next iteration
-                    messages.append(choice.message)
-                    messages.extend(tool_response_messages)
-
-                n_iter += 1
-                if n_iter >= (max_infer_iters or 10):
-                    break
-
-                # Continue with next iteration of the loop
-                continue
-            else:
-                # No tool calls - convert response to message and we're done
-                for choice in current_response.choices:
                     output_messages.append(await _convert_chat_choice_to_response_message(choice))

+            # execute non-function tool calls
+            for tool_call in non_function_tool_calls:
+                tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
+                if tool_call_log:
+                    output_messages.append(tool_call_log)
+                if tool_response_message:
+                    next_turn_messages.append(tool_response_message)
+
+            for tool_call in function_tool_calls:
+                output_messages.append(
+                    OpenAIResponseOutputMessageFunctionToolCall(
+                        arguments=tool_call.function.arguments or "",
+                        call_id=tool_call.id,
+                        name=tool_call.function.name or "",
+                        id=f"fc_{uuid.uuid4()}",
+                        status="completed",
+                    )
+                )
+
+            if not function_tool_calls and not non_function_tool_calls:
                 break
+
+            if function_tool_calls:
+                logger.info("Exiting inference loop since there is a function (client-side) tool call")
+                break
+
+            n_iter += 1
+            if n_iter >= max_infer_iters:
+                logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
+                break
+
+            messages = next_turn_messages

         # Create final response
         final_response = OpenAIResponseObject(
             created_at=created_at,
@@ -683,15 +555,15 @@ class OpenAIResponsesImpl:
             output=output_messages,
         )

+        # Emit response.completed
+        yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
+
         if store:
             await self._store_response(
                 response=final_response,
                 input=input,
             )

-        # Emit response.completed
-        yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
-
     async def _convert_response_tools_to_chat_tools(
         self, tools: list[OpenAIResponseInputTool]
     ) -> tuple[
@@ -784,73 +656,6 @@ class OpenAIResponsesImpl:
                 raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
         return chat_tools, mcp_tool_to_server, mcp_list_message

-    async def _execute_tool_calls_only(
-        self,
-        choice: OpenAIChoice,
-        ctx: ChatCompletionContext,
-    ) -> tuple[list[OpenAIResponseOutput], list[OpenAIMessageParam]]:
-        """Execute tool calls and return output messages and tool response messages for next inference."""
-        output_messages: list[OpenAIResponseOutput] = []
-        tool_response_messages: list[OpenAIMessageParam] = []
-
-        if not isinstance(choice.message, OpenAIAssistantMessageParam):
-            return output_messages, tool_response_messages
-
-        if not choice.message.tool_calls:
-            return output_messages, tool_response_messages
-
-        for tool_call in choice.message.tool_calls:
-            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
-            if tool_call_log:
-                output_messages.append(tool_call_log)
-            if further_input:
-                tool_response_messages.append(further_input)
-
-        return output_messages, tool_response_messages
-
-    async def _execute_tool_and_return_final_output(
-        self,
-        choice: OpenAIChoice,
-        ctx: ChatCompletionContext,
-    ) -> list[OpenAIResponseOutput]:
-        output_messages: list[OpenAIResponseOutput] = []
-
-        if not isinstance(choice.message, OpenAIAssistantMessageParam):
-            return output_messages
-
-        if not choice.message.tool_calls:
-            return output_messages
-
-        next_turn_messages = ctx.messages.copy()
-
-        # Add the assistant message with tool_calls response to the messages list
-        next_turn_messages.append(choice.message)
-
-        for tool_call in choice.message.tool_calls:
-            # TODO: telemetry spans for tool calls
-            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
-            if tool_call_log:
-                output_messages.append(tool_call_log)
-            if further_input:
-                next_turn_messages.append(further_input)
-
-        tool_results_chat_response = await self.inference_api.openai_chat_completion(
-            model=ctx.model,
-            messages=next_turn_messages,
-            stream=ctx.stream,
-            temperature=ctx.temperature,
-        )
-        # type cast to appease mypy: this is needed because we don't handle streaming properly :)
-        tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
-
-        # Huge TODO: these are NOT the final outputs, we must keep the loop going
-        tool_final_outputs = [
-            await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
-        ]
-        # TODO: Wire in annotations with URLs, titles, etc to these output messages
-        output_messages.extend(tool_final_outputs)
-        return output_messages
-
     async def _execute_tool_call(
         self,
         tool_call: OpenAIChatCompletionToolCall,
@@ -939,3 +744,15 @@ class OpenAIResponsesImpl:
         input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)

         return message, input_message
+
+
+def _is_function_tool_call(
+    tool_call: OpenAIChatCompletionToolCall,
+    tools: list[OpenAIResponseInputTool],
+) -> bool:
+    if not tool_call.function:
+        return False
+    for t in tools:
+        if t.type == "function" and t.name == tool_call.function.name:
+            return True
+    return False
@@ -345,21 +345,27 @@ class OllamaInferenceAdapter(
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing

+        if model.provider_resource_id is None:
+            raise ValueError("Model provider_resource_id cannot be None")
+
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
-            await self.client.pull(model.provider_resource_id)
+            # TODO: you should pull here only if the model is not found in a list
+            response = await self.client.list()
+            if model.provider_resource_id not in [m.model for m in response.models]:
+                await self.client.pull(model.provider_resource_id)

         # we use list() here instead of ps() -
         # - ps() only lists running models, not available models
         # - models not currently running are run by the ollama server as needed
         response = await self.client.list()
-        available_models = [m["model"] for m in response["models"]]
-        if model.provider_resource_id is None:
-            raise ValueError("Model provider_resource_id cannot be None")
+        available_models = [m.model for m in response.models]
         provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id is None:
             provider_resource_id = model.provider_resource_id
         if provider_resource_id not in available_models:
-            available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
+            available_models_latest = [m.model.split(":latest")[0] for m in response.models]
             if provider_resource_id in available_models_latest:
                 logger.warning(
                     f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
@@ -80,6 +80,37 @@ def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_ru
     )


+async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
+    value = load_chat_completion_fixture(fixture)
+    yield ChatCompletionChunk(
+        id=value.id,
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(
+                    content=c.message.content,
+                    role=c.message.role,
+                    tool_calls=[
+                        ChoiceDeltaToolCall(
+                            index=0,
+                            id=t.id,
+                            function=ChoiceDeltaToolCallFunction(
+                                name=t.function.name,
+                                arguments=t.function.arguments,
+                            ),
+                        )
+                        for t in (c.message.tool_calls or [])
+                    ],
+                ),
+            )
+            for c in value.choices
+        ],
+        created=1,
+        model=value.model,
+        object="chat.completion.chunk",
+    )
+
+
 @pytest.mark.asyncio
 async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input."""
@@ -88,8 +119,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
     model = "meta-llama/Llama-3.1-8B-Instruct"

     # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     result = await openai_responses_impl.create_openai_response(
@@ -104,7 +134,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
         messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
         response_format=OpenAIResponseFormatText(),
         tools=None,
-        stream=False,
+        stream=True,
         temperature=0.1,
     )
     openai_responses_impl.responses_store.store_response_object.assert_called_once()
@@ -121,20 +151,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
     input_text = "What is the capital of Ireland?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    # Load the chat completion fixtures
-    tool_call_completion = load_chat_completion_fixture("tool_call_completion.yaml")
-    tool_response_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-
     mock_inference_api.openai_chat_completion.side_effect = [
-        tool_call_completion,
-        tool_response_completion,
+        fake_stream("tool_call_completion.yaml"),
+        fake_stream(),
     ]

     openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
         identifier="web_search",
         provider_id="client",
         toolgroup_id="web_search",
-        tool_host="client",
         description="Search the web for information",
         parameters=[
             ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
@@ -189,7 +214,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
     input_text = "How hot it is in San Francisco today?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    async def fake_stream():
+    async def fake_stream_toolcall():
         yield ChatCompletionChunk(
             id="123",
             choices=[
@@ -212,7 +237,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
             object="chat.completion.chunk",
         )

-    mock_inference_api.openai_chat_completion.return_value = fake_stream()
+    mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()

     # Execute
     result = await openai_responses_impl.create_openai_response(
@@ -271,7 +296,7 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
     ]
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -399,9 +424,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."

-    # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -440,8 +463,7 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."

-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -499,8 +521,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(

     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -674,8 +696,8 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(

     current_input = "Now what is 3+3?"
     model = "meta-llama/Llama-3.1-8B-Instruct"
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute - Create response with previous_response_id
     result = await openai_responses_impl.create_openai_response(
@@ -732,9 +754,7 @@ async def test_create_openai_response_with_text_format(
     input_text = "How hot it is in San Francisco today?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     _result = await openai_responses_impl.create_openai_response(