Implement the include parameter, specifically for adding logprobs to the output message

Shabana Baig 2025-11-30 11:53:57 -05:00
parent 4ff0c25c52
commit 7d6c0aaf11
10 changed files with 255 additions and 8 deletions

View file

@@ -40,6 +40,8 @@ from llama_stack_api import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAITokenLogProb,
OpenAITopLogProb,
Order,
RerankResponse,
RoutingTable,
@@ -313,8 +315,34 @@ class InferenceRouter(Inference):
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
# Convert logprobs from the chat completion format to the responses format:
# chat completion returns a list of ChatCompletionTokenLogprob, but
# OpenAIChoice expects a list of OpenAITokenLogProb
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
converted_logprobs = []
for token_logprob in choice_delta.logprobs.content:
top_logprobs = None
if token_logprob.top_logprobs:
top_logprobs = [
OpenAITopLogProb(
token=tlp.token,
bytes=tlp.bytes,
logprob=tlp.logprob,
)
for tlp in token_logprob.top_logprobs
]
converted_logprobs.append(
OpenAITokenLogProb(
token=token_logprob.token,
bytes=token_logprob.bytes,
logprob=token_logprob.logprob,
top_logprobs=top_logprobs,
)
)
# Replace the choice delta's logprobs content with the converted objects
choice_delta.logprobs.content = converted_logprobs
current_choice_data["logprobs_content_parts"].extend(converted_logprobs)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:

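The conversion above amounts to a small mapping helper; a minimal sketch of the same logic (the standalone function and its name are illustrative, not part of this commit):

from llama_stack_api import OpenAITokenLogProb, OpenAITopLogProb

def convert_token_logprobs(content) -> list[OpenAITokenLogProb]:
    # Map chat-completion token logprobs onto the Responses-API logprob models.
    converted: list[OpenAITokenLogProb] = []
    for token_logprob in content:
        top = None
        if token_logprob.top_logprobs:
            top = [
                OpenAITopLogProb(token=t.token, bytes=t.bytes, logprob=t.logprob)
                for t in token_logprob.top_logprobs
            ]
        converted.append(
            OpenAITokenLogProb(
                token=token_logprob.token,
                bytes=token_logprob.bytes,
                logprob=token_logprob.logprob,
                top_logprobs=top,
            )
        )
    return converted
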
View file

@@ -43,6 +43,7 @@ from llama_stack_api import (
Order,
Prompts,
ResponseGuardrailSpec,
ResponseItemInclude,
Safety,
ToolGroups,
ToolRuntime,
@@ -265,7 +266,7 @@ class OpenAIResponsesImpl:
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
include: list[ResponseItemInclude] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
@@ -331,7 +332,7 @@
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
parallel_tool_calls: bool | None = None,
@@ -392,6 +393,7 @@
parallel_tool_calls=parallel_tool_calls,
max_tool_calls=max_tool_calls,
metadata=metadata,
include=include,
)
if stream:
@@ -445,6 +447,7 @@
parallel_tool_calls: bool | None = True,
max_tool_calls: int | None = None,
metadata: dict[str, str] | None = None,
include: list[ResponseItemInclude] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
@@ -494,6 +497,7 @@
instructions=instructions,
max_tool_calls=max_tool_calls,
metadata=metadata,
include=include,
)
# Stream the response

View file

@@ -24,6 +24,7 @@ from llama_stack_api import (
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
@@ -68,6 +69,7 @@ from llama_stack_api import (
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
OpenAIToolMessageParam,
ResponseItemInclude,
Safety,
WebSearchToolTypes,
)
@@ -121,6 +123,7 @@ class StreamingResponseOrchestrator:
parallel_tool_calls: bool | None = None,
max_tool_calls: int | None = None,
metadata: dict[str, str] | None = None,
include: list[ResponseItemInclude] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -139,6 +142,7 @@
# Max number of total calls to built-in tools that can be processed in a response
self.max_tool_calls = max_tool_calls
self.metadata = metadata
self.include = include
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -245,6 +249,10 @@
)
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
logprobs = bool(self.include and ResponseItemInclude.message_output_text_logprobs in self.include)
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
@@ -256,6 +264,7 @@
stream_options={
"include_usage": True,
},
logprobs=logprobs,
)
completion_result = await self.inference_api.openai_chat_completion(params)
@@ -577,6 +586,7 @@
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
chat_response_logprobs = []
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
@@ -606,6 +616,12 @@
chunk_events: list[OpenAIResponseObjectStream] = []
for chunk_choice in chunk.choices:
# Collect logprobs if present
chunk_logprobs = None
if chunk_choice.logprobs and chunk_choice.logprobs.content:
chunk_logprobs = chunk_choice.logprobs.content
chat_response_logprobs.extend(chunk_logprobs)
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit output_item.added for the message on first content
@@ -645,6 +661,7 @@
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
logprobs=chunk_logprobs,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
@@ -848,6 +865,7 @@
OpenAIResponseOutputMessageContentOutputText(
text=final_text,
annotations=[],
logprobs=chat_response_logprobs if chat_response_logprobs else None,
)
)
@@ -875,6 +893,7 @@
message_item_id=message_item_id,
tool_call_item_ids=tool_call_item_ids,
content_part_emitted=content_part_emitted,
logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
)
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
@@ -896,6 +915,7 @@
message=assistant_message,
finish_reason=result.finish_reason,
index=0,
logprobs=result.logprobs,
)
],
created=result.created,

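Taken together, the streaming path collects token logprobs chunk by chunk and attaches them to the final output_text part; roughly the following (a sketch, assuming OpenAIResponseOutputMessageContentOutputText is exported from llama_stack_api like the other response types, with completion_stream and final_text standing in for values the orchestrator already holds):

from llama_stack_api import OpenAIResponseOutputMessageContentOutputText, OpenAITokenLogProb

async def collect_stream_logprobs(completion_stream, final_text: str):
    # Gather token logprobs from every streamed chat-completion chunk and attach
    # them to the final output_text content part (None when nothing was collected).
    chat_response_logprobs: list[OpenAITokenLogProb] = []
    async for chunk in completion_stream:
        for choice in chunk.choices:
            if choice.logprobs and choice.logprobs.content:
                chat_response_logprobs.extend(choice.logprobs.content)
    return OpenAIResponseOutputMessageContentOutputText(
        text=final_text,
        annotations=[],
        logprobs=chat_response_logprobs or None,
    )
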
View file

@@ -28,6 +28,7 @@ from llama_stack_api import (
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseTool,
OpenAIResponseToolMCP,
OpenAITokenLogProb,
)
@@ -54,6 +55,7 @@ class ChatCompletionResult:
message_item_id: str # For streaming events
tool_call_item_ids: dict[int, str] # For streaming events
content_part_emitted: bool # Tracking state
logprobs: list[OpenAITokenLogProb] | None = None
@property
def content_text(self) -> str:

View file

@@ -115,10 +115,17 @@ async def convert_chat_choice_to_response_message(
)
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
logprobs = choice.logprobs.content if choice.logprobs and choice.logprobs.content else None
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
content=[
OpenAIResponseOutputMessageContentOutputText(
text=clean_text,
annotations=list(annotations),
logprobs=logprobs,
)
],
status="completed",
role="assistant",
)

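Downstream, the resulting message carries logprobs inside its output_text content part; an illustrative instance (ids and text are made up, and OpenAIResponseMessage / OpenAIResponseOutputMessageContentOutputText are assumed to be exported from llama_stack_api):

from llama_stack_api import (
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAITokenLogProb,
)

message = OpenAIResponseMessage(
    id="msg_example",
    content=[
        OpenAIResponseOutputMessageContentOutputText(
            text="Humans live on Earth.",
            annotations=[],
            logprobs=[OpenAITokenLogProb(token="Earth", bytes=None, logprob=-0.02)],
        )
    ],
    status="completed",
    role="assistant",
)
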
View file

@@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
from . import common # noqa: F401
# Import all public API symbols
from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
from .batches import Batches, BatchObject, ListBatchesResponse
from .benchmarks import (
Benchmark,
@@ -764,6 +764,7 @@ __all__ = [
"ResponseFormatType",
"ResponseGuardrail",
"ResponseGuardrailSpec",
"ResponseItemInclude",
"RouteInfo",
"RoutingTable",
"RowsDataSource",

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from enum import StrEnum
from typing import Annotated, Protocol, runtime_checkable
from pydantic import BaseModel
@@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):
ResponseGuardrail = str | ResponseGuardrailSpec
class ResponseItemInclude(StrEnum):
"""
Specify additional output data to include in the model response.
"""
web_search_call_action_sources = "web_search_call.action.sources"
code_interpreter_call_outputs = "code_interpreter_call.outputs"
computer_call_output_output_image_url = "computer_call_output.output.image_url"
file_search_call_results = "file_search_call.results"
message_input_image_image_url = "message.input_image.image_url"
message_output_text_logprobs = "message.output_text.logprobs"
reasoning_encrypted_content = "reasoning.encrypted_content"
@runtime_checkable
class Agents(Protocol):
"""Agents
@@ -80,7 +95,7 @@ class Agents(Protocol):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
include: list[ResponseItemInclude] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
guardrails: Annotated[
list[ResponseGuardrail] | None,

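Because ResponseItemInclude is a StrEnum, its members compare equal to their raw string values, so clients can keep sending plain strings (as the tests below do) while the API is typed against the enum:

from llama_stack_api import ResponseItemInclude

assert ResponseItemInclude.message_output_text_logprobs == "message.output_text.logprobs"

# Membership checks work with either representation, so the orchestrator's
# `in self.include` test succeeds whether `include` holds enum members or raw strings.
include = ["message.output_text.logprobs"]
assert ResponseItemInclude.message_output_text_logprobs in include
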
View file

@@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
token: str
bytes: list[int] | None = None
logprob: float
top_logprobs: list[OpenAITopLogProb]
top_logprobs: list[OpenAITopLogProb] | None = None
@json_schema_type

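With top_logprobs now optional, a token logprob can be constructed even when a provider returns no alternatives; for example:

from llama_stack_api import OpenAITokenLogProb

lp = OpenAITokenLogProb(token="Earth", bytes=None, logprob=-0.01)
assert lp.top_logprobs is None
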
View file

@@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
from typing_extensions import TypedDict
from llama_stack_api.inference import OpenAITokenLogProb
from llama_stack_api.schema_utils import json_schema_type, register_schema
from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions
@@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
text: str
type: Literal["output_text"] = "output_text"
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
logprobs: list[OpenAITokenLogProb] | None = None
@json_schema_type
@@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
:param content_index: Index position within the text content
:param delta: Incremental text content being added
:param item_id: Unique identifier of the output item being updated
:param logprobs: (Optional) Token log probability details
:param output_index: Index position of the item in the output list
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.output_text.delta"
@@ -754,6 +757,7 @@
content_index: int
delta: str
item_id: str
logprobs: list[OpenAITokenLogProb] | None = None
output_index: int
sequence_number: int
type: Literal["response.output_text.delta"] = "response.output_text.delta"
@@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
type: Literal["output_text"] = "output_text"
text: str
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
logprobs: list[dict[str, Any]] | None = None
logprobs: list[OpenAITokenLogProb] | None = None
@json_schema_type

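The new optional logprobs field also rides along on the streaming text delta event; an illustrative instance (values are made up, and the event class is assumed to be importable alongside the other response types):

from llama_stack_api import OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAITokenLogProb

delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
    content_index=0,
    delta="Earth",
    item_id="msg_123",
    logprobs=[OpenAITokenLogProb(token="Earth", bytes=None, logprob=-0.02)],
    output_index=0,
    sequence_number=1,
)
assert delta_event.type == "response.output_text.delta"
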
View file

@@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
from .streaming_assertions import StreamingValidator
def provider_from_model(responses_client, text_model_id):
models = {m.id: m for m in responses_client.models.list()}
models.update(
{m.custom_metadata["provider_resource_id"]: m for m in responses_client.models.list() if m.custom_metadata}
)
provider_id = models[text_model_id].custom_metadata["provider_id"]
providers = {p.provider_id: p for p in responses_client.providers.list()}
return providers[provider_id]
def skip_if_chat_completions_logprobs_not_supported(responses_client, text_model_id):
provider_type = provider_from_model(responses_client, text_model_id).provider_type
if provider_type in ("remote::ollama",):
pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
@pytest.mark.parametrize("case", basic_test_cases)
def test_response_non_streaming_basic(responses_client, text_model_id, case):
response = responses_client.responses.create(
@@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn_expected.lower() in output_text
def test_include_logprobs_non_streaming(client_with_models, text_model_id):
"""Test logprobs inclusion in responses with the include parameter."""
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
input = "Which planet do humans live on?"
include = ["message.output_text.logprobs"]
# Create a response without include=["message.output_text.logprobs"]
response_w_o_logprobs = client_with_models.responses.create(
model=text_model_id,
input=input,
stream=False,
)
# Verify we got one output message and no logprobs
assert len(response_w_o_logprobs.output) == 1
message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
# Create a response with include=["message.output_text.logprobs"]
response_with_logprobs = client_with_models.responses.create(
model=text_model_id,
input=input,
stream=False,
include=include,
)
# Verify we got one output message and that it has logprobs
assert len(response_with_logprobs.output) == 1
message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
assert message_outputs[0].content[0].logprobs is not None, (
"Expected logprobs in the returned response, but none were returned"
)
def test_include_logprobs_streaming(client_with_models, text_model_id):
"""Test logprobs inclusion in responses with the include parameter."""
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
input = "Which planet do humans live on?"
include = ["message.output_text.logprobs"]
# Create a streaming response with include=["message.output_text.logprobs"]
stream = client_with_models.responses.create(
model=text_model_id,
input=input,
stream=True,
include=include,
)
for chunk in stream:
if chunk.type == "response.completed":
message_outputs = [output for output in chunk.response.output if output.type == "message"]
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
assert message_outputs[0].content[0].logprobs is not None, (
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
)
elif chunk.type == "response.output_item.done":
content = chunk.item.content
assert len(content) == 1, f"Expected one content object, got {len(content)}"
assert content[0].logprobs is not None, (
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
)
elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
assert chunk.logprobs is not None, (
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
)
elif chunk.type == "response.content_part.done":
assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
def test_include_logprobs_with_web_search(client_with_models, text_model_id):
"""Test include logprobs with built-in tool."""
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
input = "Search for a positive news story from today."
include = ["message.output_text.logprobs"]
tools = [
{
"type": "web_search",
}
]
# Create a response with a built-in tool and include=["message.output_text.logprobs"]
response = client_with_models.responses.create(
model=text_model_id,
input=input,
stream=False,
include=include,
tools=tools,
)
# Verify we got one built-in tool call and that the output message has logprobs
assert len(response.output) >= 2
assert response.output[0].type == "web_search_call"
assert response.output[0].status == "completed"
message_outputs = [output for output in response.output if output.type == "message"]
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
assert message_outputs[0].content[0].logprobs is not None, (
"Expected logprobs in the returned response, but none were returned"
)
def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
"""Test include logprobs with function tools."""
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
input = "What is the weather in Paris?"
include = ["message.output_text.logprobs"]
tools = [
{
"type": "function",
"name": "get_weather",
"description": "Get weather information for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name (e.g., 'New York', 'London')",
},
},
},
},
]
# Create a response with a function tool and include=["message.output_text.logprobs"]
response = client_with_models.responses.create(
model=text_model_id,
input=input,
stream=False,
include=include,
tools=tools,
)
# Verify we got one function tool call and no message output (so no logprobs)
assert len(response.output) == 1
assert response.output[0].type == "function_call"
assert response.output[0].name == "get_weather"
assert response.output[0].status == "completed"
message_outputs = [output for output in response.output if output.type == "message"]
assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"
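
End to end, the feature is exercised the way these tests do; a minimal client-side sketch (the client object and model id are placeholders):

response = client.responses.create(
    model="my-model-id",
    input="Which planet do humans live on?",
    include=["message.output_text.logprobs"],
)
print(response.output[0].content[0].logprobs)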