mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
refactor: unify stream and non-stream impls for responses (#2388)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Integration Tests / test-matrix (http, datasets) (push) Failing after 9s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, providers) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 9s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (library, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Test External Providers / test-external-providers (venv) (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 11s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 7s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 30s
Pre-commit / pre-commit (push) Successful in 1m18s
The non-streaming version is just a small layer on top of the streaming version - just pluck off the final `response.completed` event and return that as the response!

This PR also includes a couple of other changes I ended up making while working on it on a flight:

- changes to `ollama` so it does not pull embedding models unconditionally
- a small fix to the library client to make the stream and non-stream cases a bit more symmetric
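For illustration, here is a minimal sketch of that pattern in plain Python. The event shapes and function names are hypothetical stand-ins, not the actual llama-stack API; the point is just that the non-streaming path drains the streaming generator and returns the payload of the terminal `response.completed` event:

```python
from typing import Any, AsyncIterator


async def _stream_events() -> AsyncIterator[dict[str, Any]]:
    # Stand-in for the streaming implementation: it yields incremental
    # events and finishes with a terminal "response.completed" event.
    yield {"type": "response.output_text.delta", "delta": "Dub"}
    yield {"type": "response.output_text.delta", "delta": "lin"}
    yield {"type": "response.completed", "response": {"output_text": "Dublin"}}


async def create_response(stream: bool = False) -> Any:
    if stream:
        # Streaming callers consume the events themselves.
        return _stream_events()
    # Non-streaming: drain the stream and pluck off the final event.
    final: dict[str, Any] | None = None
    async for event in _stream_events():
        if event["type"] == "response.completed":
            final = event["response"]
    assert final is not None, "stream ended without response.completed"
    return final


# asyncio.run(create_response())  ->  {'output_text': 'Dublin'}
```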
This commit is contained in:
parent ef885d2147
commit 3251b44d8a
4 changed files with 166 additions and 315 deletions
@@ -80,6 +80,37 @@ def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_ru
     )


+async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
+    value = load_chat_completion_fixture(fixture)
+    yield ChatCompletionChunk(
+        id=value.id,
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(
+                    content=c.message.content,
+                    role=c.message.role,
+                    tool_calls=[
+                        ChoiceDeltaToolCall(
+                            index=0,
+                            id=t.id,
+                            function=ChoiceDeltaToolCallFunction(
+                                name=t.function.name,
+                                arguments=t.function.arguments,
+                            ),
+                        )
+                        for t in (c.message.tool_calls or [])
+                    ],
+                ),
+            )
+            for c in value.choices
+        ],
+        created=1,
+        model=value.model,
+        object="chat.completion.chunk",
+    )
+
+
 @pytest.mark.asyncio
 async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input."""
@@ -88,8 +119,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
     model = "meta-llama/Llama-3.1-8B-Instruct"

     # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     result = await openai_responses_impl.create_openai_response(
@@ -104,7 +134,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
         messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
         response_format=OpenAIResponseFormatText(),
         tools=None,
-        stream=False,
+        stream=True,
         temperature=0.1,
     )
     openai_responses_impl.responses_store.store_response_object.assert_called_once()
@@ -121,20 +151,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
     input_text = "What is the capital of Ireland?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    # Load the chat completion fixtures
-    tool_call_completion = load_chat_completion_fixture("tool_call_completion.yaml")
-    tool_response_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-
     mock_inference_api.openai_chat_completion.side_effect = [
-        tool_call_completion,
-        tool_response_completion,
+        fake_stream("tool_call_completion.yaml"),
+        fake_stream(),
     ]

     openai_responses_impl.tool_groups_api.get_tool.return_value = Tool(
         identifier="web_search",
         provider_id="client",
         toolgroup_id="web_search",
         tool_host="client",
         description="Search the web for information",
         parameters=[
             ToolParameter(name="query", parameter_type="string", description="The query to search for", required=True)
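One detail worth noting in the hunk above: `side_effect` is given two separate `fake_stream(...)` calls rather than one shared instance, because an async generator can be consumed only once; a second inference call against an exhausted generator would see no chunks. A quick self-contained illustration (plain asyncio, nothing llama-stack-specific assumed):

```python
import asyncio


async def gen():
    yield "chunk"


async def main() -> None:
    g = gen()
    print([c async for c in g])      # ['chunk']
    print([c async for c in g])      # [] -- already exhausted
    print([c async for c in gen()])  # ['chunk'] -- a fresh instance works


asyncio.run(main())
```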
@@ -189,7 +214,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
     input_text = "How hot it is in San Francisco today?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    async def fake_stream():
+    async def fake_stream_toolcall():
         yield ChatCompletionChunk(
             id="123",
             choices=[
@@ -212,7 +237,7 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
             object="chat.completion.chunk",
         )

-    mock_inference_api.openai_chat_completion.return_value = fake_stream()
+    mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()

     # Execute
     result = await openai_responses_impl.create_openai_response(
@@ -271,7 +296,7 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
     ]
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    mock_inference_api.openai_chat_completion.return_value = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -399,9 +424,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."

-    # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -440,8 +463,7 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."

-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -499,8 +521,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(

     model = "meta-llama/Llama-3.1-8B-Instruct"
     instructions = "You are a geography expert. Provide concise answers."
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     await openai_responses_impl.create_openai_response(
@@ -674,8 +696,8 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(

     current_input = "Now what is 3+3?"
     model = "meta-llama/Llama-3.1-8B-Instruct"
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute - Create response with previous_response_id
     result = await openai_responses_impl.create_openai_response(
@@ -732,9 +754,7 @@ async def test_create_openai_response_with_text_format(
     input_text = "How hot it is in San Francisco today?"
     model = "meta-llama/Llama-3.1-8B-Instruct"

-    # Load the chat completion fixture
-    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
-    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()

     # Execute
     _result = await openai_responses_impl.create_openai_response(