mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-12 21:58:38 +00:00
chore: refactor (chat)completions endpoints to use shared params struct (#3761)
# What does this PR do? Converts openai(_chat)_completions params to pydantic BaseModel to reduce code duplication across all providers. ## Test Plan CI --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/llamastack/llama-stack/pull/3761). * #3777 * __->__ #3761
This commit is contained in:
parent
6954fe2274
commit
80d58ab519
33 changed files with 599 additions and 890 deletions
|
@ -33,6 +33,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequest,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
|
@ -161,15 +162,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
OpenAIChatCompletionRequest(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
|
@ -256,13 +259,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == "What is the capital of Ireland?"
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
assert second_call.kwargs["messages"][-1].content == "Dublin"
|
||||
assert second_call.kwargs["temperature"] == 0.1
|
||||
second_params = second_call.args[0]
|
||||
assert second_params.messages[-1].content == "Dublin"
|
||||
assert second_params.temperature == 0.1
|
||||
|
||||
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
||||
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
||||
|
@ -348,9 +353,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
@ -394,9 +400,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
|
||||
def assert_common_expectations(chunks) -> None:
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
assert len(chunks[0].response.output) == 0
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
|
@ -512,7 +519,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
|
||||
# Verify the the correct messages were sent to the inference API i.e.
|
||||
# All of the responses message were convered to the chat completion message objects
|
||||
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
params = call_args.args[0]
|
||||
inference_messages = params.messages
|
||||
for i, m in enumerate(input_messages):
|
||||
if isinstance(m.content, str):
|
||||
assert inference_messages[i].content == m.content
|
||||
|
@ -680,7 +689,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 2
|
||||
|
@ -718,7 +728,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4 # 1 system + 3 input messages
|
||||
|
@ -778,7 +789,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
|
@ -1018,7 +1030,8 @@ async def test_reuse_mcp_tool_list(
|
|||
)
|
||||
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
tools_seen = second_call.kwargs["tools"]
|
||||
second_params = second_call.args[0]
|
||||
tools_seen = second_params.tools
|
||||
assert len(tools_seen) == 1
|
||||
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||
|
@ -1065,8 +1078,9 @@ async def test_create_openai_response_with_text_format(
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["response_format"] == response_format
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.response_format == response_format
|
||||
|
||||
|
||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue