From 7d6c0aaf11bdb59eba92dafb2d65b7ea1f7d20ea Mon Sep 17 00:00:00 2001
From: Shabana Baig <43451943+s-akhtar-baig@users.noreply.github.com>
Date: Sun, 30 Nov 2025 11:53:57 -0500
Subject: [PATCH] Implement include parameter specifically for adding logprobs
 in the output message

---
 src/llama_stack/core/routers/inference.py    |  30 +++-
 .../responses/openai_responses.py            |   8 +-
 .../meta_reference/responses/streaming.py    |  20 +++
 .../agents/meta_reference/responses/types.py |   2 +
 .../agents/meta_reference/responses/utils.py |   9 +-
 src/llama_stack_api/__init__.py              |   3 +-
 src/llama_stack_api/agents.py                |  17 +-
 src/llama_stack_api/inference.py             |   2 +-
 src/llama_stack_api/openai_responses.py      |   6 +-
 .../responses/test_basic_responses.py        | 166 ++++++++++++++++++
 10 files changed, 255 insertions(+), 8 deletions(-)

diff --git a/src/llama_stack/core/routers/inference.py b/src/llama_stack/core/routers/inference.py
index 8a7ffaa5f..4e58ecdaf 100644
--- a/src/llama_stack/core/routers/inference.py
+++ b/src/llama_stack/core/routers/inference.py
@@ -40,6 +40,8 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
+    OpenAITokenLogProb,
+    OpenAITopLogProb,
     Order,
     RerankResponse,
     RoutingTable,
@@ -313,8 +315,34 @@ class InferenceRouter(Inference):
                     )
                 if choice_delta.finish_reason:
                     current_choice_data["finish_reason"] = choice_delta.finish_reason
+
+                # Convert logprobs from chat completion format to responses format
+                # Chat completion returns list of ChatCompletionTokenLogprob, but
+                # expecting list of OpenAITokenLogProb in OpenAIChoice
                 if choice_delta.logprobs and choice_delta.logprobs.content:
-                    current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                    converted_logprobs = []
+                    for token_logprob in choice_delta.logprobs.content:
+                        top_logprobs = None
+                        if token_logprob.top_logprobs:
+                            top_logprobs = [
+                                OpenAITopLogProb(
+                                    token=tlp.token,
+                                    bytes=tlp.bytes,
+                                    logprob=tlp.logprob,
+                                )
+                                for tlp in token_logprob.top_logprobs
+                            ]
+                        converted_logprobs.append(
+                            OpenAITokenLogProb(
+                                token=token_logprob.token,
+                                bytes=token_logprob.bytes,
+                                logprob=token_logprob.logprob,
+                                top_logprobs=top_logprobs,
+                            )
+                        )
+                    # Update choice delta with the newly formatted logprobs object
+                    choice_delta.logprobs.content = converted_logprobs
+                    current_choice_data["logprobs_content_parts"].extend(converted_logprobs)

             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 9cf30908c..68e6e3eb2 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -43,6 +43,7 @@ from llama_stack_api import (
     Order,
     Prompts,
     ResponseGuardrailSpec,
+    ResponseItemInclude,
     Safety,
     ToolGroups,
     ToolRuntime,
@@ -265,7 +266,7 @@ class OpenAIResponsesImpl:
         response_id: str,
         after: str | None = None,
         before: str | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         limit: int | None = 20,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseInputItem:
@@ -331,7 +332,7 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
         parallel_tool_calls: bool | None = None,
@@ -392,6 +393,7 @@ class OpenAIResponsesImpl:
             parallel_tool_calls=parallel_tool_calls,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         if stream:
@@ -445,6 +447,7 @@ class OpenAIResponsesImpl:
         parallel_tool_calls: bool | None = True,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types
@@ -494,6 +497,7 @@ class OpenAIResponsesImpl:
             instructions=instructions,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         # Stream the response
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index c778d65e7..00feb0c0f 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -24,6 +24,7 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIChoiceLogprobs,
     OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
@@ -68,6 +69,7 @@ from llama_stack_api import (
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
     OpenAIToolMessageParam,
+    ResponseItemInclude,
     Safety,
     WebSearchToolTypes,
 )
@@ -121,6 +123,7 @@ class StreamingResponseOrchestrator:
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -139,6 +142,7 @@ class StreamingResponseOrchestrator:
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
         self.metadata = metadata
+        self.include = include
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -245,6 +249,10 @@ class StreamingResponseOrchestrator:
             )
             logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")

+            logprobs = (
+                True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else False
+            )
+
             params = OpenAIChatCompletionRequestWithExtraBody(
                 model=self.ctx.model,
                 messages=messages,
@@ -256,6 +264,7 @@ class StreamingResponseOrchestrator:
                 stream_options={
                     "include_usage": True,
                 },
+                logprobs=logprobs,
             )
             completion_result = await self.inference_api.openai_chat_completion(params)

@@ -577,6 +586,7 @@ class StreamingResponseOrchestrator:
         chunk_created = 0
         chunk_model = ""
         chunk_finish_reason = ""
+        chat_response_logprobs = []

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
@@ -606,6 +616,12 @@ class StreamingResponseOrchestrator:
             chunk_events: list[OpenAIResponseObjectStream] = []

             for chunk_choice in chunk.choices:
+                # Collect logprobs if present
+                chunk_logprobs = None
+                if chunk_choice.logprobs and chunk_choice.logprobs.content:
+                    chunk_logprobs = chunk_choice.logprobs.content
+                    chat_response_logprobs.extend(chunk_logprobs)
+
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
                     # Emit output_item.added for the message on first content
@@ -645,6 +661,7 @@ class StreamingResponseOrchestrator:
                             content_index=content_index,
                             delta=chunk_choice.delta.content,
                             item_id=message_item_id,
+                            logprobs=chunk_logprobs,
                             output_index=message_output_index,
                             sequence_number=self.sequence_number,
                         )
@@ -848,6 +865,7 @@ class StreamingResponseOrchestrator:
                 OpenAIResponseOutputMessageContentOutputText(
                     text=final_text,
                     annotations=[],
+                    logprobs=chat_response_logprobs if chat_response_logprobs else None,
                 )
             )

@@ -875,6 +893,7 @@ class StreamingResponseOrchestrator:
             message_item_id=message_item_id,
             tool_call_item_ids=tool_call_item_ids,
             content_part_emitted=content_part_emitted,
+            logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
         )

     def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
@@ -896,6 +915,7 @@ class StreamingResponseOrchestrator:
                     message=assistant_message,
                     finish_reason=result.finish_reason,
                     index=0,
+                    logprobs=result.logprobs,
                 )
             ],
             created=result.created,
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index f6efcee22..5e52db6b7 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -28,6 +28,7 @@ from llama_stack_api import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseTool,
     OpenAIResponseToolMCP,
+    OpenAITokenLogProb,
 )


@@ -54,6 +55,7 @@ class ChatCompletionResult:
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
+    logprobs: list[OpenAITokenLogProb] | None = None

     @property
     def content_text(self) -> str:
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 7bbf6bd30..59abe2d69 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -115,10 +115,17 @@ async def convert_chat_choice_to_response_message(
         )

     annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+    logprobs = choice.logprobs.content if choice.logprobs and choice.logprobs.content else None

     return OpenAIResponseMessage(
         id=message_id or f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
+        content=[
+            OpenAIResponseOutputMessageContentOutputText(
+                text=clean_text,
+                annotations=list(annotations),
+                logprobs=logprobs,
+            )
+        ],
         status="completed",
         role="assistant",
     )
diff --git a/src/llama_stack_api/__init__.py b/src/llama_stack_api/__init__.py
index b6fe2fd23..ddaa73150 100644
--- a/src/llama_stack_api/__init__.py
+++ b/src/llama_stack_api/__init__.py
@@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
 from . import common  # noqa: F401

 # Import all public API symbols
-from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
+from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
 from .batches import Batches, BatchObject, ListBatchesResponse
 from .benchmarks import (
     Benchmark,
@@ -764,6 +764,7 @@ __all__ = [
     "ResponseFormatType",
     "ResponseGuardrail",
     "ResponseGuardrailSpec",
+    "ResponseItemInclude",
     "RouteInfo",
     "RoutingTable",
     "RowsDataSource",
diff --git a/src/llama_stack_api/agents.py b/src/llama_stack_api/agents.py
index 8d3b489e1..cb0de33a1 100644
--- a/src/llama_stack_api/agents.py
+++ b/src/llama_stack_api/agents.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
+from enum import StrEnum
 from typing import Annotated, Protocol, runtime_checkable

 from pydantic import BaseModel
@@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):
 ResponseGuardrail = str | ResponseGuardrailSpec


+class ResponseItemInclude(StrEnum):
+    """
+    Specify additional output data to include in the model response.
+    """
+
+    web_search_call_action_sources = "web_search_call.action.sources"
+    code_interpreter_call_outputs = "code_interpreter_call.outputs"
+    computer_call_output_output_image_url = "computer_call_output.output.image_url"
+    file_search_call_results = "file_search_call.results"
+    message_input_image_image_url = "message.input_image.image_url"
+    message_output_text_logprobs = "message.output_text.logprobs"
+    reasoning_encrypted_content = "reasoning.encrypted_content"
+
+
 @runtime_checkable
 class Agents(Protocol):
     """Agents
@@ -80,7 +95,7 @@ class Agents(Protocol):
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
         guardrails: Annotated[
             list[ResponseGuardrail] | None,
diff --git a/src/llama_stack_api/inference.py b/src/llama_stack_api/inference.py
index 4a169486a..7ff3f1803 100644
--- a/src/llama_stack_api/inference.py
+++ b/src/llama_stack_api/inference.py
@@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
     token: str
     bytes: list[int] | None = None
     logprob: float
-    top_logprobs: list[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb] | None = None


 @json_schema_type
diff --git a/src/llama_stack_api/openai_responses.py b/src/llama_stack_api/openai_responses.py
index 177d2314a..f038bf77b 100644
--- a/src/llama_stack_api/openai_responses.py
+++ b/src/llama_stack_api/openai_responses.py
@@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict

+from llama_stack_api.inference import OpenAITokenLogProb
 from llama_stack_api.schema_utils import json_schema_type, register_schema
 from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions

@@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
     text: str
     type: Literal["output_text"] = "output_text"
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     :param content_index: Index position within the text content
     :param delta: Incremental text content being added
     :param item_id: Unique identifier of the output item being updated
+    :param logprobs: (Optional) Token log probability details
     :param output_index: Index position of the item in the output list
     :param sequence_number: Sequential number for ordering streaming events
     :param type: Event type identifier, always "response.output_text.delta"
@@ -754,6 +757,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     content_index: int
     delta: str
     item_id: str
+    logprobs: list[OpenAITokenLogProb] | None = None
     output_index: int
     sequence_number: int
     type: Literal["response.output_text.delta"] = "response.output_text.delta"
@@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
     type: Literal["output_text"] = "output_text"
     text: str
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
-    logprobs: list[dict[str, Any]] | None = None
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type
diff --git a/tests/integration/responses/test_basic_responses.py b/tests/integration/responses/test_basic_responses.py
index d72a43375..7a3ca285a 100644
--- a/tests/integration/responses/test_basic_responses.py
+++ b/tests/integration/responses/test_basic_responses.py
@@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
 from .streaming_assertions import StreamingValidator


+def provider_from_model(responses_client, text_model_id):
+    models = {m.id: m for m in responses_client.models.list()}
+    models.update(
+        {m.custom_metadata["provider_resource_id"]: m for m in responses_client.models.list() if m.custom_metadata}
+    )
+    provider_id = models[text_model_id].custom_metadata["provider_id"]
+    providers = {p.provider_id: p for p in responses_client.providers.list()}
+    return providers[provider_id]
+
+
+def skip_if_chat_completions_logprobs_not_supported(responses_client, text_model_id):
+    provider_type = provider_from_model(responses_client, text_model_id).provider_type
+    if provider_type in ("remote::ollama",):
+        pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
+
+
 @pytest.mark.parametrize("case", basic_test_cases)
 def test_response_non_streaming_basic(responses_client, text_model_id, case):
     response = responses_client.responses.create(
@@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
         previous_response_id = response.id
         output_text = response.output_text.lower()
         assert turn_expected.lower() in output_text
+
+
+def test_include_logprobs_non_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a response without include["message.output_text.logprobs"]
+    response_w_o_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+    )
+
+    # Verify we got one output message and no logprobs
+    assert len(response_w_o_logprobs.output) == 1
+    message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
+
+    # Create a response with include["message.output_text.logprobs"]
+    response_with_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+    )
+
+    # Verify we got one output message and output message has logprobs
+    assert len(response_with_logprobs.output) == 1
+    message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a streaming response with include["message.output_text.logprobs"]
+    stream = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=True,
+        include=include,
+    )
+
+    for chunk in stream:
+        if chunk.type == "response.completed":
+            message_outputs = [output for output in chunk.response.output if output.type == "message"]
+            assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+            assert message_outputs[0].content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.output_item.done":
+            content = chunk.item.content
+            assert len(content) == 1, f"Expected one content object, got {len(content)}"
+            assert content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
+            assert chunk.logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.content_part.done":
+            assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
+
+
+def test_include_logprobs_with_web_search(client_with_models, text_model_id):
+    """Test include logprobs with built-in tool."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Search for a positive news story from today."
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "web_search",
+        }
+    ]
+
+    # Create a response with built-in tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one built-in tool call and output message has logprobs
+    assert len(response.output) >= 2
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
+    """Test include logprobs with function tools."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "What is the weather in Paris?"
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # Create a response with function tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one function tool call and no logprobs
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"
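
Usage sketch (an illustration, not part of the patch): once a Llama Stack server carries this change, a client can request per-token log probabilities on the assistant message through the Responses API include parameter, the same way the integration tests above do. The endpoint URL, API key, and model id below are placeholders, and the use of an OpenAI-compatible client against the server's /v1 routes is an assumption.

# Hypothetical client-side example (placeholder endpoint, key, and model id).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/", api_key="none")

response = client.responses.create(
    model="my-model-id",
    input="Which planet do humans live on?",
    include=["message.output_text.logprobs"],
)

# With the patch applied, each output_text content part carries OpenAITokenLogProb entries.
for item in response.output:
    if item.type == "message":
        for part in item.content:
            if part.type == "output_text" and part.logprobs:
                for lp in part.logprobs:
                    print(lp.token, lp.logprob)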