From 924213a689c284cacf8be26c9a09146ee7d40811 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Wed, 30 Apr 2025 17:01:00 -0400
Subject: [PATCH] Responses API: Finish wiring up function tool calls

This finishes the plumbing for function tool calls and adds a basic
verification test (that passes for me locally against Llama 4 Scout in
vllm).

Signed-off-by: Ben Browning
---
 docs/_static/llama-stack-spec.html            | 41 ++++++++++++++++++-
 docs/_static/llama-stack-spec.yaml            | 29 +++++++++++++
 llama_stack/apis/agents/openai_responses.py   | 14 ++++++-
 .../agents/meta_reference/openai_responses.py | 28 +++++++++++--
 .../fixtures/test_cases/responses.yaml        | 20 +++++++++
 .../openai_api/test_responses.py              | 22 ++++++++++
 6 files changed, 148 insertions(+), 6 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index b3fed05cc..15342de86 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -6920,16 +6920,55 @@
           },
           {
             "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           }
         ],
         "discriminator": {
           "propertyName": "type",
           "mapping": {
             "message": "#/components/schemas/OpenAIResponseMessage",
-            "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+            "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+            "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
           }
         }
       },
+      "OpenAIResponseOutputMessageFunctionToolCall": {
+        "type": "object",
+        "properties": {
+          "arguments": {
+            "type": "string"
+          },
+          "call_id": {
+            "type": "string"
+          },
+          "name": {
+            "type": "string"
+          },
+          "type": {
+            "type": "string",
+            "const": "function_call",
+            "default": "function_call"
+          },
+          "id": {
+            "type": "string"
+          },
+          "status": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "arguments",
+          "call_id",
+          "name",
+          "type",
+          "id",
+          "status"
+        ],
+        "title": "OpenAIResponseOutputMessageFunctionToolCall"
+      },
       "OpenAIResponseObjectStream": {
         "oneOf": [
           {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index bf003783f..bc71ce915 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -4833,11 +4833,40 @@ components:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseMessage'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       discriminator:
         propertyName: type
         mapping:
           message: '#/components/schemas/OpenAIResponseMessage'
           web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+          function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+    "OpenAIResponseOutputMessageFunctionToolCall":
+      type: object
+      properties:
+        arguments:
+          type: string
+        call_id:
+          type: string
+        name:
+          type: string
+        type:
+          type: string
+          const: function_call
+          default: function_call
+        id:
+          type: string
+        status:
+          type: string
+      additionalProperties: false
+      required:
+        - arguments
+        - call_id
+        - name
+        - type
+        - id
+        - status
+      title: >-
+        OpenAIResponseOutputMessageFunctionToolCall
     OpenAIResponseObjectStream:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index 55005fa50..39ef806ef 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -77,8 +77,20 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
+    arguments: str
+    call_id: str
+    name: str
+    type: Literal["function_call"] = "function_call"
+    id: str
+    status: str
+
+
 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseMessage
+    | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 990d864ec..1a2dd1ff6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFunction,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -24,6 +25,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponsePreviousResponseWithInputItems,
 )
@@ -221,10 +223,28 @@ class OpenAIResponsesImpl:
         chat_response = OpenAIChatCompletion(**chat_response.model_dump())
 
         output_messages: list[OpenAIResponseOutput] = []
-        if chat_response.choices[0].message.tool_calls:
-            output_messages.extend(
-                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
-            )
+        # TODO: should we check more than choices[0] here?
+        if chat_response.choices[0].message.tool_calls and tools:
+            # TODO: Should we support a mix of custom and builtin tools?
+            # in other words, should we check for more than tools[0]?
+            if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                choice = chat_response.choices[0]
+                for tool_call in choice.message.tool_calls:
+                    output_messages.append(
+                        OpenAIResponseOutputMessageFunctionToolCall(
+                            arguments=tool_call.function.arguments or "",
+                            call_id=tool_call.id,
+                            name=tool_call.function.name or "",
+                            id=f"fc_{uuid.uuid4()}",
+                            status="completed",
+                        )
+                    )
+            else:
+                output_messages.extend(
+                    await self._execute_tool_and_return_final_output(
+                        model, stream, chat_response, messages, temperature
+                    )
+                )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
         response = OpenAIResponseObject(
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index f235b2ea8..ed5f571e8 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -31,6 +31,26 @@ test_response_web_search:
             search_context_size: "low"
         output: "128"
 
+test_response_custom_tool:
+  test_name: test_response_custom_tool
+  test_params:
+    case:
+      - case_id: "sf_weather"
+        input: "What's the weather like in San Francisco?"
+        tools:
+          - type: function
+            name: get_weather
+            description: Get current temperature for a given location.
+            parameters:
+              additionalProperties: false
+              properties:
+                location:
+                  description: "City and country e.g. Bogot\xE1, Colombia"
+                  type: string
+              required:
+                - location
+              type: object
+
 test_response_image:
   test_name: test_response_image
   test_params:
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index cc7ec320c..e279b9b38 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -124,6 +124,28 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_custom_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_image"]["test_params"]["case"],
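
For reviewers, here is roughly how this new path looks from the client side once the patch is applied. This is a minimal usage sketch, not part of the patch: the base URL and model id below are placeholders for whatever your Llama Stack deployment serves, and it assumes the same `openai` Python client the verification test uses.

from openai import OpenAI

# Placeholder endpoint and credentials; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder model id
    input="What's the weather like in San Francisco?",
    tools=[
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City and country e.g. Bogotá, Colombia",
                    },
                },
                "required": ["location"],
                "additionalProperties": False,
            },
        },
    ],
    stream=False,
)

# With this patch, a function tool call is returned to the caller as a
# `function_call` output item instead of being executed server-side.
call = response.output[0]
print(call.type)       # "function_call"
print(call.name)       # "get_weather"
print(call.arguments)  # e.g. '{"location": "San Francisco, USA"}'
print(call.call_id)    # the id your application would echo back with the tool's result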