mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
Implement include parameter specifically for adding logprobs in the output message
This commit is contained in:
parent
4ff0c25c52
commit
7d6c0aaf11
10 changed files with 255 additions and 8 deletions
|
|
@ -40,6 +40,8 @@ from llama_stack_api import (
|
|||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAITokenLogProb,
|
||||
OpenAITopLogProb,
|
||||
Order,
|
||||
RerankResponse,
|
||||
RoutingTable,
|
||||
|
|
@ -313,8 +315,34 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
|
||||
# Convert logprobs from chat completion format to responses format
|
||||
# Chat completion returns list of ChatCompletionTokenLogprob, but
|
||||
# expecting list of OpenAITokenLogProb in OpenAIChoice
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
converted_logprobs = []
|
||||
for token_logprob in choice_delta.logprobs.content:
|
||||
top_logprobs = None
|
||||
if token_logprob.top_logprobs:
|
||||
top_logprobs = [
|
||||
OpenAITopLogProb(
|
||||
token=tlp.token,
|
||||
bytes=tlp.bytes,
|
||||
logprob=tlp.logprob,
|
||||
)
|
||||
for tlp in token_logprob.top_logprobs
|
||||
]
|
||||
converted_logprobs.append(
|
||||
OpenAITokenLogProb(
|
||||
token=token_logprob.token,
|
||||
bytes=token_logprob.bytes,
|
||||
logprob=token_logprob.logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
# Update choice delta with the newly formatted logprobs object
|
||||
choice_delta.logprobs.content = converted_logprobs
|
||||
current_choice_data["logprobs_content_parts"].extend(converted_logprobs)
|
||||
|
||||
# Compute metrics on final chunk
|
||||
if chunk.choices and chunk.choices[0].finish_reason:
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from llama_stack_api import (
|
|||
Order,
|
||||
Prompts,
|
||||
ResponseGuardrailSpec,
|
||||
ResponseItemInclude,
|
||||
Safety,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
|
|
@ -265,7 +266,7 @@ class OpenAIResponsesImpl:
|
|||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
|
|
@ -331,7 +332,7 @@ class OpenAIResponsesImpl:
|
|||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
|
|
@ -392,6 +393,7 @@ class OpenAIResponsesImpl:
|
|||
parallel_tool_calls=parallel_tool_calls,
|
||||
max_tool_calls=max_tool_calls,
|
||||
metadata=metadata,
|
||||
include=include,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
@ -445,6 +447,7 @@ class OpenAIResponsesImpl:
|
|||
parallel_tool_calls: bool | None = True,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# These should never be None when called from create_openai_response (which sets defaults)
|
||||
# but we assert here to help mypy understand the types
|
||||
|
|
@ -494,6 +497,7 @@ class OpenAIResponsesImpl:
|
|||
instructions=instructions,
|
||||
max_tool_calls=max_tool_calls,
|
||||
metadata=metadata,
|
||||
include=include,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from llama_stack_api import (
|
|||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseContentPartReasoningText,
|
||||
|
|
@ -68,6 +69,7 @@ from llama_stack_api import (
|
|||
OpenAIResponseUsageInputTokensDetails,
|
||||
OpenAIResponseUsageOutputTokensDetails,
|
||||
OpenAIToolMessageParam,
|
||||
ResponseItemInclude,
|
||||
Safety,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
|
|
@ -121,6 +123,7 @@ class StreamingResponseOrchestrator:
|
|||
parallel_tool_calls: bool | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
|
@ -139,6 +142,7 @@ class StreamingResponseOrchestrator:
|
|||
# Max number of total calls to built-in tools that can be processed in a response
|
||||
self.max_tool_calls = max_tool_calls
|
||||
self.metadata = metadata
|
||||
self.include = include
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
|
|
@ -245,6 +249,10 @@ class StreamingResponseOrchestrator:
|
|||
)
|
||||
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
|
||||
|
||||
logprobs = (
|
||||
True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else False
|
||||
)
|
||||
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
|
|
@ -256,6 +264,7 @@ class StreamingResponseOrchestrator:
|
|||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
logprobs=logprobs,
|
||||
)
|
||||
completion_result = await self.inference_api.openai_chat_completion(params)
|
||||
|
||||
|
|
@ -577,6 +586,7 @@ class StreamingResponseOrchestrator:
|
|||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
chat_response_logprobs = []
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
|
@ -606,6 +616,12 @@ class StreamingResponseOrchestrator:
|
|||
chunk_events: list[OpenAIResponseObjectStream] = []
|
||||
|
||||
for chunk_choice in chunk.choices:
|
||||
# Collect logprobs if present
|
||||
chunk_logprobs = None
|
||||
if chunk_choice.logprobs and chunk_choice.logprobs.content:
|
||||
chunk_logprobs = chunk_choice.logprobs.content
|
||||
chat_response_logprobs.extend(chunk_logprobs)
|
||||
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
# Emit output_item.added for the message on first content
|
||||
|
|
@ -645,6 +661,7 @@ class StreamingResponseOrchestrator:
|
|||
content_index=content_index,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
logprobs=chunk_logprobs,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
|
@ -848,6 +865,7 @@ class StreamingResponseOrchestrator:
|
|||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text=final_text,
|
||||
annotations=[],
|
||||
logprobs=chat_response_logprobs if chat_response_logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -875,6 +893,7 @@ class StreamingResponseOrchestrator:
|
|||
message_item_id=message_item_id,
|
||||
tool_call_item_ids=tool_call_item_ids,
|
||||
content_part_emitted=content_part_emitted,
|
||||
logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
|
||||
)
|
||||
|
||||
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
|
||||
|
|
@ -896,6 +915,7 @@ class StreamingResponseOrchestrator:
|
|||
message=assistant_message,
|
||||
finish_reason=result.finish_reason,
|
||||
index=0,
|
||||
logprobs=result.logprobs,
|
||||
)
|
||||
],
|
||||
created=result.created,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from llama_stack_api import (
|
|||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
OpenAITokenLogProb,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -54,6 +55,7 @@ class ChatCompletionResult:
|
|||
message_item_id: str # For streaming events
|
||||
tool_call_item_ids: dict[int, str] # For streaming events
|
||||
content_part_emitted: bool # Tracking state
|
||||
logprobs: list[OpenAITokenLogProb] | None = None
|
||||
|
||||
@property
|
||||
def content_text(self) -> str:
|
||||
|
|
|
|||
|
|
@ -115,10 +115,17 @@ async def convert_chat_choice_to_response_message(
|
|||
)
|
||||
|
||||
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
|
||||
logprobs = choice.logprobs.content if choice.logprobs and choice.logprobs.content else None
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text=clean_text,
|
||||
annotations=list(annotations),
|
||||
logprobs=logprobs,
|
||||
)
|
||||
],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
|
|||
from . import common # noqa: F401
|
||||
|
||||
# Import all public API symbols
|
||||
from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
|
||||
from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
from .benchmarks import (
|
||||
Benchmark,
|
||||
|
|
@ -764,6 +764,7 @@ __all__ = [
|
|||
"ResponseFormatType",
|
||||
"ResponseGuardrail",
|
||||
"ResponseGuardrailSpec",
|
||||
"ResponseItemInclude",
|
||||
"RouteInfo",
|
||||
"RoutingTable",
|
||||
"RowsDataSource",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):
|
|||
ResponseGuardrail = str | ResponseGuardrailSpec
|
||||
|
||||
|
||||
class ResponseItemInclude(StrEnum):
|
||||
"""
|
||||
Specify additional output data to include in the model response.
|
||||
"""
|
||||
|
||||
web_search_call_action_sources = "web_search_call.action.sources"
|
||||
code_interpreter_call_outputs = "code_interpreter_call.outputs"
|
||||
computer_call_output_output_image_url = "computer_call_output.output.image_url"
|
||||
file_search_call_results = "file_search_call.results"
|
||||
message_input_image_image_url = "message.input_image.image_url"
|
||||
message_output_text_logprobs = "message.output_text.logprobs"
|
||||
reasoning_encrypted_content = "reasoning.encrypted_content"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents
|
||||
|
|
@ -80,7 +95,7 @@ class Agents(Protocol):
|
|||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
include: list[ResponseItemInclude] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
guardrails: Annotated[
|
||||
list[ResponseGuardrail] | None,
|
||||
|
|
|
|||
|
|
@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
|
|||
token: str
|
||||
bytes: list[int] | None = None
|
||||
logprob: float
|
||||
top_logprobs: list[OpenAITopLogProb]
|
||||
top_logprobs: list[OpenAITopLogProb] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
|
|||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack_api.inference import OpenAITokenLogProb
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema
|
||||
from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
|||
text: str
|
||||
type: Literal["output_text"] = "output_text"
|
||||
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
|
||||
logprobs: list[OpenAITokenLogProb] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
|||
:param content_index: Index position within the text content
|
||||
:param delta: Incremental text content being added
|
||||
:param item_id: Unique identifier of the output item being updated
|
||||
:param logprobs: (Optional) Token log probability details
|
||||
:param output_index: Index position of the item in the output list
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.output_text.delta"
|
||||
|
|
@ -754,6 +757,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
|||
content_index: int
|
||||
delta: str
|
||||
item_id: str
|
||||
logprobs: list[OpenAITokenLogProb] | None = None
|
||||
output_index: int
|
||||
sequence_number: int
|
||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||
|
|
@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
|
|||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
|
||||
logprobs: list[dict[str, Any]] | None = None
|
||||
logprobs: list[OpenAITokenLogProb] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
|
|||
from .streaming_assertions import StreamingValidator
|
||||
|
||||
|
||||
def provider_from_model(responses_client, text_model_id):
|
||||
models = {m.id: m for m in responses_client.models.list()}
|
||||
models.update(
|
||||
{m.custom_metadata["provider_resource_id"]: m for m in responses_client.models.list() if m.custom_metadata}
|
||||
)
|
||||
provider_id = models[text_model_id].custom_metadata["provider_id"]
|
||||
providers = {p.provider_id: p for p in responses_client.providers.list()}
|
||||
return providers[provider_id]
|
||||
|
||||
|
||||
def skip_if_chat_completions_logprobs_not_supported(responses_client, text_model_id):
|
||||
provider_type = provider_from_model(responses_client, text_model_id).provider_type
|
||||
if provider_type in ("remote::ollama",):
|
||||
pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", basic_test_cases)
|
||||
def test_response_non_streaming_basic(responses_client, text_model_id, case):
|
||||
response = responses_client.responses.create(
|
||||
|
|
@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
|
|||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn_expected.lower() in output_text
|
||||
|
||||
|
||||
def test_include_logprobs_non_streaming(client_with_models, text_model_id):
|
||||
"""Test logprobs inclusion in responses with the include parameter."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Which planet do humans live on?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
|
||||
# Create a response without include["message.output_text.logprobs"]
|
||||
response_w_o_logprobs = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify we got one output message and no logprobs
|
||||
assert len(response_w_o_logprobs.output) == 1
|
||||
message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
|
||||
|
||||
# Create a response with include["message.output_text.logprobs"]
|
||||
response_with_logprobs = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
)
|
||||
|
||||
# Verify we got one output message and output message has logprobs
|
||||
assert len(response_with_logprobs.output) == 1
|
||||
message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
"Expected logprobs in the returned response, but none were returned"
|
||||
)
|
||||
|
||||
|
||||
def test_include_logprobs_streaming(client_with_models, text_model_id):
|
||||
"""Test logprobs inclusion in responses with the include parameter."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Which planet do humans live on?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
|
||||
# Create a streaming response with include["message.output_text.logprobs"]
|
||||
stream = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=True,
|
||||
include=include,
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.type == "response.completed":
|
||||
message_outputs = [output for output in chunk.response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type == "response.output_item.done":
|
||||
content = chunk.item.content
|
||||
assert len(content) == 1, f"Expected one content object, got {len(content)}"
|
||||
assert content[0].logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
|
||||
assert chunk.logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type == "response.content_part.done":
|
||||
assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
|
||||
|
||||
|
||||
def test_include_logprobs_with_web_search(client_with_models, text_model_id):
|
||||
"""Test include logprobs with built-in tool."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Search for a positive news story from today."
|
||||
include = ["message.output_text.logprobs"]
|
||||
tools = [
|
||||
{
|
||||
"type": "web_search",
|
||||
}
|
||||
]
|
||||
|
||||
# Create a response with built-in tool and include["message.output_text.logprobs"]
|
||||
response = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we got one built-in tool call and output message has logprobs
|
||||
assert len(response.output) >= 2
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
"Expected logprobs in the returned response, but none were returned"
|
||||
)
|
||||
|
||||
|
||||
def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
|
||||
"""Test include logprobs with function tools."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "What is the weather in Paris?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Create a response with function tool and include["message.output_text.logprobs"]
|
||||
response = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we got one function tool call and no logprobs
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].name == "get_weather"
|
||||
assert response.output[0].status == "completed"
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue