Implement the include parameter, specifically for returning logprobs in the output message

Author: Shabana Baig (2025-11-30 11:53:57 -05:00)
Parent: 4ff0c25c52
Commit: 7d6c0aaf11
10 changed files with 255 additions and 8 deletions
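In practice, this lets a Responses API caller ask for per-token log probabilities by passing include=["message.output_text.logprobs"]. A minimal usage sketch, assuming an OpenAI-compatible client pointed at a running Llama Stack server (the base URL, API key, and model id below are placeholders, and client-side model support for the logprobs field may vary by SDK version):

    from openai import OpenAI

    # Any OpenAI-compatible client can call the Responses API served by Llama Stack.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    response = client.responses.create(
        model="my-model-id",  # placeholder
        input="Which planet do humans live on?",
        include=["message.output_text.logprobs"],
    )

    # With the include flag set, each output_text content part carries per-token logprobs.
    message = next(o for o in response.output if o.type == "message")
    for token_logprob in message.content[0].logprobs:
        print(token_logprob.token, token_logprob.logprob)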

View file

@@ -40,6 +40,8 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
+    OpenAITokenLogProb,
+    OpenAITopLogProb,
     Order,
     RerankResponse,
     RoutingTable,

@@ -313,8 +315,34 @@ class InferenceRouter(Inference):
                     )
                 if choice_delta.finish_reason:
                     current_choice_data["finish_reason"] = choice_delta.finish_reason
+                # Convert logprobs from chat completion format to responses format
+                # Chat completion returns list of ChatCompletionTokenLogprob, but
+                # expecting list of OpenAITokenLogProb in OpenAIChoice
                 if choice_delta.logprobs and choice_delta.logprobs.content:
-                    current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                    converted_logprobs = []
+                    for token_logprob in choice_delta.logprobs.content:
+                        top_logprobs = None
+                        if token_logprob.top_logprobs:
+                            top_logprobs = [
+                                OpenAITopLogProb(
+                                    token=tlp.token,
+                                    bytes=tlp.bytes,
+                                    logprob=tlp.logprob,
+                                )
+                                for tlp in token_logprob.top_logprobs
+                            ]
+                        converted_logprobs.append(
+                            OpenAITokenLogProb(
+                                token=token_logprob.token,
+                                bytes=token_logprob.bytes,
+                                logprob=token_logprob.logprob,
+                                top_logprobs=top_logprobs,
+                            )
+                        )
+                    # Update choice delta with the newly formatted logprobs object
+                    choice_delta.logprobs.content = converted_logprobs
+                    current_choice_data["logprobs_content_parts"].extend(converted_logprobs)

             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
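For reference, the target models used in the conversion above are the Responses-side pydantic types. A small constructed example (values are illustrative; note that top_logprobs only becomes optional through the inference model change further down in this commit):

    from llama_stack_api import OpenAITokenLogProb, OpenAITopLogProb

    # A bare token logprob, as produced when top_logprobs was not requested.
    tok = OpenAITokenLogProb(token="Earth", bytes=[69, 97, 114, 116, 104], logprob=-0.012)

    # The same token with alternatives attached.
    tok_with_top = OpenAITokenLogProb(
        token="Earth",
        bytes=[69, 97, 114, 116, 104],
        logprob=-0.012,
        top_logprobs=[OpenAITopLogProb(token="Mars", bytes=[77, 97, 114, 115], logprob=-4.7)],
    )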

View file

@@ -43,6 +43,7 @@ from llama_stack_api import (
     Order,
     Prompts,
     ResponseGuardrailSpec,
+    ResponseItemInclude,
     Safety,
     ToolGroups,
     ToolRuntime,

@@ -265,7 +266,7 @@ class OpenAIResponsesImpl:
         response_id: str,
         after: str | None = None,
         before: str | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         limit: int | None = 20,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseInputItem:

@@ -331,7 +332,7 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
         parallel_tool_calls: bool | None = None,

@@ -392,6 +393,7 @@
             parallel_tool_calls=parallel_tool_calls,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         if stream:

@@ -445,6 +447,7 @@
         parallel_tool_calls: bool | None = True,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types

@@ -494,6 +497,7 @@
             instructions=instructions,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         # Stream the response

View file

@@ -24,6 +24,7 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIChoiceLogprobs,
     OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,

@@ -68,6 +69,7 @@ from llama_stack_api import (
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
     OpenAIToolMessageParam,
+    ResponseItemInclude,
     Safety,
     WebSearchToolTypes,
 )

@@ -121,6 +123,7 @@ class StreamingResponseOrchestrator:
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx

@@ -139,6 +142,7 @@
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
         self.metadata = metadata
+        self.include = include
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (

@@ -245,6 +249,10 @@
             )
             logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")

+            logprobs = (
+                True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else False
+            )
             params = OpenAIChatCompletionRequestWithExtraBody(
                 model=self.ctx.model,
                 messages=messages,

@@ -256,6 +264,7 @@
                 stream_options={
                     "include_usage": True,
                 },
+                logprobs=logprobs,
             )

             completion_result = await self.inference_api.openai_chat_completion(params)

@@ -577,6 +586,7 @@
         chunk_created = 0
         chunk_model = ""
         chunk_finish_reason = ""
+        chat_response_logprobs = []

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"

@@ -606,6 +616,12 @@
             chunk_events: list[OpenAIResponseObjectStream] = []

             for chunk_choice in chunk.choices:
+                # Collect logprobs if present
+                chunk_logprobs = None
+                if chunk_choice.logprobs and chunk_choice.logprobs.content:
+                    chunk_logprobs = chunk_choice.logprobs.content
+                    chat_response_logprobs.extend(chunk_logprobs)
+
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
                     # Emit output_item.added for the message on first content

@@ -645,6 +661,7 @@
                             content_index=content_index,
                             delta=chunk_choice.delta.content,
                             item_id=message_item_id,
+                            logprobs=chunk_logprobs,
                             output_index=message_output_index,
                             sequence_number=self.sequence_number,
                         )

@@ -848,6 +865,7 @@
                 OpenAIResponseOutputMessageContentOutputText(
                     text=final_text,
                     annotations=[],
+                    logprobs=chat_response_logprobs if chat_response_logprobs else None,
                 )
             )

@@ -875,6 +893,7 @@
             message_item_id=message_item_id,
             tool_call_item_ids=tool_call_item_ids,
             content_part_emitted=content_part_emitted,
+            logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
         )

     def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:

@@ -896,6 +915,7 @@
                     message=assistant_message,
                     finish_reason=result.finish_reason,
                     index=0,
+                    logprobs=result.logprobs,
                 )
             ],
             created=result.created,
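On the wire, the orchestration above surfaces logprobs both on the incremental response.output_text.delta events and on the final message content. A consumer-side sketch, reusing the placeholder client and model id from the first example:

    stream = client.responses.create(
        model="my-model-id",  # placeholder
        input="Which planet do humans live on?",
        include=["message.output_text.logprobs"],
        stream=True,
    )

    for event in stream:
        if event.type == "response.output_text.delta" and event.logprobs:
            # Each text delta now carries the logprobs for the tokens it contains.
            for lp in event.logprobs:
                print(lp.token, lp.logprob)
        elif event.type == "response.completed":
            message = next(o for o in event.response.output if o.type == "message")
            assert message.content[0].logprobs is not None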

View file

@@ -28,6 +28,7 @@ from llama_stack_api import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseTool,
     OpenAIResponseToolMCP,
+    OpenAITokenLogProb,
 )

@@ -54,6 +55,7 @@ class ChatCompletionResult:
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
+    logprobs: list[OpenAITokenLogProb] | None = None

     @property
     def content_text(self) -> str:

View file

@@ -115,10 +115,17 @@ async def convert_chat_choice_to_response_message(
     )
     annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+    logprobs = choice.logprobs.content if choice.logprobs and choice.logprobs.content else None

     return OpenAIResponseMessage(
         id=message_id or f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
+        content=[
+            OpenAIResponseOutputMessageContentOutputText(
+                text=clean_text,
+                annotations=list(annotations),
+                logprobs=logprobs,
+            )
+        ],
         status="completed",
         role="assistant",
     )

View file

@@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
 from . import common  # noqa: F401

 # Import all public API symbols
-from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
+from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
 from .batches import Batches, BatchObject, ListBatchesResponse
 from .benchmarks import (
     Benchmark,

@@ -764,6 +764,7 @@ __all__ = [
     "ResponseFormatType",
     "ResponseGuardrail",
     "ResponseGuardrailSpec",
+    "ResponseItemInclude",
     "RouteInfo",
     "RoutingTable",
     "RowsDataSource",

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
+from enum import StrEnum
 from typing import Annotated, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):

 ResponseGuardrail = str | ResponseGuardrailSpec


+class ResponseItemInclude(StrEnum):
+    """
+    Specify additional output data to include in the model response.
+    """
+
+    web_search_call_action_sources = "web_search_call.action.sources"
+    code_interpreter_call_outputs = "code_interpreter_call.outputs"
+    computer_call_output_output_image_url = "computer_call_output.output.image_url"
+    file_search_call_results = "file_search_call.results"
+    message_input_image_image_url = "message.input_image.image_url"
+    message_output_text_logprobs = "message.output_text.logprobs"
+    reasoning_encrypted_content = "reasoning.encrypted_content"
+
+
 @runtime_checkable
 class Agents(Protocol):
     """Agents

@@ -80,7 +95,7 @@ class Agents(Protocol):
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
         guardrails: Annotated[
             list[ResponseGuardrail] | None,
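Because ResponseItemInclude is a StrEnum, its members compare equal to the raw strings the OpenAI API uses, so callers that pass plain strings keep working and the orchestrator's membership check stays simple. A standalone illustration (trimmed copy of the enum, not part of the diff):

    from enum import StrEnum

    class ResponseItemInclude(StrEnum):
        message_output_text_logprobs = "message.output_text.logprobs"
        file_search_call_results = "file_search_call.results"

    # Parse the wire-level string from a request into the enum...
    include = [ResponseItemInclude("message.output_text.logprobs")]

    # ...and both enum-based and string-based checks hold.
    assert ResponseItemInclude.message_output_text_logprobs in include
    assert include[0] == "message.output_text.logprobs"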

View file

@@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
     token: str
     bytes: list[int] | None = None
     logprob: float
-    top_logprobs: list[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb] | None = None


 @json_schema_type

View file

@@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict

+from llama_stack_api.inference import OpenAITokenLogProb
 from llama_stack_api.schema_utils import json_schema_type, register_schema
 from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions

@@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
     text: str
     type: Literal["output_text"] = "output_text"
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type

@@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     :param content_index: Index position within the text content
     :param delta: Incremental text content being added
     :param item_id: Unique identifier of the output item being updated
+    :param logprobs: (Optional) Token log probability details
     :param output_index: Index position of the item in the output list
     :param sequence_number: Sequential number for ordering streaming events
     :param type: Event type identifier, always "response.output_text.delta"

@@ -754,6 +757,7 @@
     content_index: int
     delta: str
     item_id: str
+    logprobs: list[OpenAITokenLogProb] | None = None
     output_index: int
     sequence_number: int
     type: Literal["response.output_text.delta"] = "response.output_text.delta"

@@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
     type: Literal["output_text"] = "output_text"
     text: str
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
-    logprobs: list[dict[str, Any]] | None = None
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type
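Taken together, these model changes let an output_text content part carry typed logprobs end to end. A small serialization sketch (assuming both models are exported from llama_stack_api, as they are used elsewhere in this commit; token and logprob values are illustrative):

    from llama_stack_api import OpenAIResponseOutputMessageContentOutputText, OpenAITokenLogProb

    part = OpenAIResponseOutputMessageContentOutputText(
        text="Humans live on Earth.",
        annotations=[],
        logprobs=[
            OpenAITokenLogProb(token="Humans", logprob=-0.01),
            OpenAITokenLogProb(token=" live", logprob=-0.03),
        ],
    )
    print(part.model_dump_json(indent=2))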

View file

@@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
 from .streaming_assertions import StreamingValidator


+def provider_from_model(responses_client, text_model_id):
+    models = {m.id: m for m in responses_client.models.list()}
+    models.update(
+        {m.custom_metadata["provider_resource_id"]: m for m in responses_client.models.list() if m.custom_metadata}
+    )
+    provider_id = models[text_model_id].custom_metadata["provider_id"]
+    providers = {p.provider_id: p for p in responses_client.providers.list()}
+    return providers[provider_id]
+
+
+def skip_if_chat_completions_logprobs_not_supported(responses_client, text_model_id):
+    provider_type = provider_from_model(responses_client, text_model_id).provider_type
+    if provider_type in ("remote::ollama",):
+        pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
+
+
 @pytest.mark.parametrize("case", basic_test_cases)
 def test_response_non_streaming_basic(responses_client, text_model_id, case):
     response = responses_client.responses.create(

@@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
         previous_response_id = response.id
         output_text = response.output_text.lower()
         assert turn_expected.lower() in output_text
+
+
+def test_include_logprobs_non_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a response without include["message.output_text.logprobs"]
+    response_w_o_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+    )
+
+    # Verify we got one output message and no logprobs
+    assert len(response_w_o_logprobs.output) == 1
+    message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
+
+    # Create a response with include["message.output_text.logprobs"]
+    response_with_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+    )
+
+    # Verify we got one output message and output message has logprobs
+    assert len(response_with_logprobs.output) == 1
+    message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a streaming response with include["message.output_text.logprobs"]
+    stream = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=True,
+        include=include,
+    )
+
+    for chunk in stream:
+        if chunk.type == "response.completed":
+            message_outputs = [output for output in chunk.response.output if output.type == "message"]
+            assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+            assert message_outputs[0].content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.output_item.done":
+            content = chunk.item.content
+            assert len(content) == 1, f"Expected one content object, got {len(content)}"
+            assert content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
+            assert chunk.logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.content_part.done":
+            assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
+
+
+def test_include_logprobs_with_web_search(client_with_models, text_model_id):
+    """Test include logprobs with built-in tool."""
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Search for a positive news story from today."
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "web_search",
+        }
+    ]
+
+    # Create a response with built-in tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one built-in tool call and output message has logprobs
+    assert len(response.output) >= 2
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
+    """Test include logprobs with function tools."""
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "What is the weather in Paris?"
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # Create a response with function tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one function tool call and no logprobs
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"