Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 01:48:05 +00:00
Implement include parameter specifically for adding logprobs in the output message
parent 4ff0c25c52
commit 7d6c0aaf11
10 changed files with 255 additions and 8 deletions
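
For orientation, a minimal usage sketch of the new `include` option, based on the integration tests added at the bottom of this diff; the `client` object and the model id are placeholders, not part of the change itself:

# Request token logprobs for the assistant message via the Responses API.
# "client" is any OpenAI-compatible client pointed at a Llama Stack server.
response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # example model id
    input="Which planet do humans live on?",
    include=["message.output_text.logprobs"],
)
message = next(o for o in response.output if o.type == "message")
for token_logprob in message.content[0].logprobs:
    print(token_logprob.token, token_logprob.logprob)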
@@ -40,6 +40,8 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
+    OpenAITokenLogProb,
+    OpenAITopLogProb,
     Order,
     RerankResponse,
     RoutingTable,
@@ -313,8 +315,34 @@ class InferenceRouter(Inference):
                     )
                     if choice_delta.finish_reason:
                         current_choice_data["finish_reason"] = choice_delta.finish_reason
+
+                    # Convert logprobs from chat completion format to responses format
+                    # Chat completion returns list of ChatCompletionTokenLogprob, but
+                    # expecting list of OpenAITokenLogProb in OpenAIChoice
                     if choice_delta.logprobs and choice_delta.logprobs.content:
-                        current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                        converted_logprobs = []
+                        for token_logprob in choice_delta.logprobs.content:
+                            top_logprobs = None
+                            if token_logprob.top_logprobs:
+                                top_logprobs = [
+                                    OpenAITopLogProb(
+                                        token=tlp.token,
+                                        bytes=tlp.bytes,
+                                        logprob=tlp.logprob,
+                                    )
+                                    for tlp in token_logprob.top_logprobs
+                                ]
+                            converted_logprobs.append(
+                                OpenAITokenLogProb(
+                                    token=token_logprob.token,
+                                    bytes=token_logprob.bytes,
+                                    logprob=token_logprob.logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+                        # Update choice delta with the newly formatted logprobs object
+                        choice_delta.logprobs.content = converted_logprobs
+                        current_choice_data["logprobs_content_parts"].extend(converted_logprobs)

             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
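The conversion above could equally live in a small standalone helper. A hedged sketch of the same ChatCompletionTokenLogprob-to-OpenAITokenLogProb mapping; the helper name is illustrative, not part of the commit:

def _convert_token_logprobs(content) -> list[OpenAITokenLogProb]:
    # Hypothetical helper mirroring the inline loop above: map each chat-completion
    # token logprob (and its optional top_logprobs) onto the responses-API models.
    converted = []
    for token_logprob in content:
        top = None
        if token_logprob.top_logprobs:
            top = [
                OpenAITopLogProb(token=t.token, bytes=t.bytes, logprob=t.logprob)
                for t in token_logprob.top_logprobs
            ]
        converted.append(
            OpenAITokenLogProb(
                token=token_logprob.token,
                bytes=token_logprob.bytes,
                logprob=token_logprob.logprob,
                top_logprobs=top,
            )
        )
    return converted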
@@ -43,6 +43,7 @@ from llama_stack_api import (
     Order,
     Prompts,
     ResponseGuardrailSpec,
+    ResponseItemInclude,
     Safety,
     ToolGroups,
     ToolRuntime,
@@ -265,7 +266,7 @@ class OpenAIResponsesImpl:
         response_id: str,
         after: str | None = None,
         before: str | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         limit: int | None = 20,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseInputItem:
@@ -331,7 +332,7 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,
         guardrails: list[str | ResponseGuardrailSpec] | None = None,
         parallel_tool_calls: bool | None = None,
@@ -392,6 +393,7 @@ class OpenAIResponsesImpl:
             parallel_tool_calls=parallel_tool_calls,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         if stream:
@@ -445,6 +447,7 @@ class OpenAIResponsesImpl:
         parallel_tool_calls: bool | None = True,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # These should never be None when called from create_openai_response (which sets defaults)
         # but we assert here to help mypy understand the types
@@ -494,6 +497,7 @@ class OpenAIResponsesImpl:
             instructions=instructions,
             max_tool_calls=max_tool_calls,
             metadata=metadata,
+            include=include,
         )

         # Stream the response
@@ -24,6 +24,7 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
+    OpenAIChoiceLogprobs,
     OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
@@ -68,6 +69,7 @@ from llama_stack_api import (
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
     OpenAIToolMessageParam,
+    ResponseItemInclude,
     Safety,
     WebSearchToolTypes,
 )
@@ -121,6 +123,7 @@ class StreamingResponseOrchestrator:
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
         metadata: dict[str, str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -139,6 +142,7 @@ class StreamingResponseOrchestrator:
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
         self.metadata = metadata
+        self.include = include
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -245,6 +249,10 @@ class StreamingResponseOrchestrator:
         )
         logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")

+        logprobs = (
+            True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else False
+        )
+
         params = OpenAIChatCompletionRequestWithExtraBody(
             model=self.ctx.model,
             messages=messages,
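The `logprobs` flag computed above is a plain boolean; an equivalent, slightly terser sketch of the same check (not part of the commit):

# Equivalent boolean form of the check introduced above.
logprobs = bool(self.include and ResponseItemInclude.message_output_text_logprobs in self.include)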
@@ -256,6 +264,7 @@ class StreamingResponseOrchestrator:
             stream_options={
                 "include_usage": True,
             },
+            logprobs=logprobs,
         )
         completion_result = await self.inference_api.openai_chat_completion(params)

@@ -577,6 +586,7 @@ class StreamingResponseOrchestrator:
         chunk_created = 0
         chunk_model = ""
         chunk_finish_reason = ""
+        chat_response_logprobs = []

         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
@@ -606,6 +616,12 @@ class StreamingResponseOrchestrator:
             chunk_events: list[OpenAIResponseObjectStream] = []

             for chunk_choice in chunk.choices:
+                # Collect logprobs if present
+                chunk_logprobs = None
+                if chunk_choice.logprobs and chunk_choice.logprobs.content:
+                    chunk_logprobs = chunk_choice.logprobs.content
+                    chat_response_logprobs.extend(chunk_logprobs)
+
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
                     # Emit output_item.added for the message on first content
@@ -645,6 +661,7 @@ class StreamingResponseOrchestrator:
                         content_index=content_index,
                         delta=chunk_choice.delta.content,
                         item_id=message_item_id,
+                        logprobs=chunk_logprobs,
                         output_index=message_output_index,
                         sequence_number=self.sequence_number,
                     )
@@ -848,6 +865,7 @@ class StreamingResponseOrchestrator:
                 OpenAIResponseOutputMessageContentOutputText(
                     text=final_text,
                     annotations=[],
+                    logprobs=chat_response_logprobs if chat_response_logprobs else None,
                 )
             )

@@ -875,6 +893,7 @@ class StreamingResponseOrchestrator:
             message_item_id=message_item_id,
             tool_call_item_ids=tool_call_item_ids,
             content_part_emitted=content_part_emitted,
+            logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None,
         )

     def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
@@ -896,6 +915,7 @@ class StreamingResponseOrchestrator:
                     message=assistant_message,
                     finish_reason=result.finish_reason,
                     index=0,
+                    logprobs=result.logprobs,
                 )
             ],
             created=result.created,
@@ -28,6 +28,7 @@ from llama_stack_api import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseTool,
     OpenAIResponseToolMCP,
+    OpenAITokenLogProb,
 )


@@ -54,6 +55,7 @@ class ChatCompletionResult:
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
+    logprobs: list[OpenAITokenLogProb] | None = None

     @property
     def content_text(self) -> str:
@@ -115,10 +115,17 @@ async def convert_chat_choice_to_response_message(
         )

     annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+    logprobs = choice.logprobs.content if choice.logprobs and choice.logprobs.content else None
+
     return OpenAIResponseMessage(
         id=message_id or f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
+        content=[
+            OpenAIResponseOutputMessageContentOutputText(
+                text=clean_text,
+                annotations=list(annotations),
+                logprobs=logprobs,
+            )
+        ],
         status="completed",
         role="assistant",
     )
@@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
 from . import common  # noqa: F401

 # Import all public API symbols
-from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
+from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
 from .batches import Batches, BatchObject, ListBatchesResponse
 from .benchmarks import (
     Benchmark,
@@ -764,6 +764,7 @@ __all__ = [
     "ResponseFormatType",
     "ResponseGuardrail",
     "ResponseGuardrailSpec",
+    "ResponseItemInclude",
     "RouteInfo",
     "RoutingTable",
     "RowsDataSource",
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
+from enum import StrEnum
 from typing import Annotated, Protocol, runtime_checkable

 from pydantic import BaseModel
@@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):
 ResponseGuardrail = str | ResponseGuardrailSpec


+class ResponseItemInclude(StrEnum):
+    """
+    Specify additional output data to include in the model response.
+    """
+
+    web_search_call_action_sources = "web_search_call.action.sources"
+    code_interpreter_call_outputs = "code_interpreter_call.outputs"
+    computer_call_output_output_image_url = "computer_call_output.output.image_url"
+    file_search_call_results = "file_search_call.results"
+    message_input_image_image_url = "message.input_image.image_url"
+    message_output_text_logprobs = "message.output_text.logprobs"
+    reasoning_encrypted_content = "reasoning.encrypted_content"
+
+
 @runtime_checkable
 class Agents(Protocol):
     """Agents
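Because `ResponseItemInclude` is a `StrEnum`, its members compare equal to the raw strings clients already send, so plain strings in `include` keep working. A quick sketch of that behavior, using an abbreviated copy of the enum above:

from enum import StrEnum

class ResponseItemInclude(StrEnum):  # abbreviated copy of the enum defined above
    message_output_text_logprobs = "message.output_text.logprobs"

# StrEnum members are strings, so equality and value lookup both work with the raw string.
assert ResponseItemInclude.message_output_text_logprobs == "message.output_text.logprobs"
assert ResponseItemInclude("message.output_text.logprobs") is ResponseItemInclude.message_output_text_logprobs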
@@ -80,7 +95,7 @@ class Agents(Protocol):
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
         guardrails: Annotated[
             list[ResponseGuardrail] | None,
@@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
     token: str
     bytes: list[int] | None = None
     logprob: float
-    top_logprobs: list[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb] | None = None


 @json_schema_type
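With `top_logprobs` now optional, a token entry can be built with or without alternatives. A small sketch using the fields shown above; the token and logprob values are made up for illustration:

# Minimal OpenAITokenLogProb instances as they now appear in message content.
with_alternatives = OpenAITokenLogProb(
    token="Earth",
    logprob=-0.01,
    top_logprobs=[OpenAITopLogProb(token="Earth", bytes=None, logprob=-0.01)],
)
without_alternatives = OpenAITokenLogProb(token="Earth", logprob=-0.01)  # top_logprobs defaults to None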
@@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict

+from llama_stack_api.inference import OpenAITokenLogProb
 from llama_stack_api.schema_utils import json_schema_type, register_schema
 from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions

@@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
     text: str
     type: Literal["output_text"] = "output_text"
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     :param content_index: Index position within the text content
     :param delta: Incremental text content being added
     :param item_id: Unique identifier of the output item being updated
+    :param logprobs: (Optional) Token log probability details
     :param output_index: Index position of the item in the output list
     :param sequence_number: Sequential number for ordering streaming events
     :param type: Event type identifier, always "response.output_text.delta"
@@ -754,6 +757,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     content_index: int
     delta: str
     item_id: str
+    logprobs: list[OpenAITokenLogProb] | None = None
     output_index: int
     sequence_number: int
     type: Literal["response.output_text.delta"] = "response.output_text.delta"
@@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
     type: Literal["output_text"] = "output_text"
     text: str
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
-    logprobs: list[dict[str, Any]] | None = None
+    logprobs: list[OpenAITokenLogProb] | None = None


 @json_schema_type
@@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
 from .streaming_assertions import StreamingValidator


+def provider_from_model(responses_client, text_model_id):
+    models = {m.id: m for m in responses_client.models.list()}
+    models.update(
+        {m.custom_metadata["provider_resource_id"]: m for m in responses_client.models.list() if m.custom_metadata}
+    )
+    provider_id = models[text_model_id].custom_metadata["provider_id"]
+    providers = {p.provider_id: p for p in responses_client.providers.list()}
+    return providers[provider_id]
+
+
+def skip_if_chat_completions_logprobs_not_supported(responses_client, text_model_id):
+    provider_type = provider_from_model(responses_client, text_model_id).provider_type
+    if provider_type in ("remote::ollama",):
+        pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
+
+
 @pytest.mark.parametrize("case", basic_test_cases)
 def test_response_non_streaming_basic(responses_client, text_model_id, case):
     response = responses_client.responses.create(
@@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
         previous_response_id = response.id
         output_text = response.output_text.lower()
         assert turn_expected.lower() in output_text
+
+
+def test_include_logprobs_non_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a response without include["message.output_text.logprobs"]
+    response_w_o_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+    )
+
+    # Verify we got one output message and no logprobs
+    assert len(response_w_o_logprobs.output) == 1
+    message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
+
+    # Create a response with include["message.output_text.logprobs"]
+    response_with_logprobs = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+    )
+
+    # Verify we got one output message and output message has logprobs
+    assert len(response_with_logprobs.output) == 1
+    message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_streaming(client_with_models, text_model_id):
+    """Test logprobs inclusion in responses with the include parameter."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Which planet do humans live on?"
+    include = ["message.output_text.logprobs"]
+
+    # Create a streaming response with include["message.output_text.logprobs"]
+    stream = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=True,
+        include=include,
+    )
+
+    for chunk in stream:
+        if chunk.type == "response.completed":
+            message_outputs = [output for output in chunk.response.output if output.type == "message"]
+            assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+            assert message_outputs[0].content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.output_item.done":
+            content = chunk.item.content
+            assert len(content) == 1, f"Expected one content object, got {len(content)}"
+            assert content[0].logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
+            assert chunk.logprobs is not None, (
+                f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
+            )
+        elif chunk.type == "response.content_part.done":
+            assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
+
+
+def test_include_logprobs_with_web_search(client_with_models, text_model_id):
+    """Test include logprobs with built-in tool."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "Search for a positive news story from today."
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "web_search",
+        }
+    ]
+
+    # Create a response with built-in tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one built-in tool call and output message has logprobs
+    assert len(response.output) >= 2
+    assert response.output[0].type == "web_search_call"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
+    assert message_outputs[0].content[0].logprobs is not None, (
+        "Expected logprobs in the returned response, but none were returned"
+    )
+
+
+def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
+    """Test include logprobs with function tools."""
+
+    skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
+
+    input = "What is the weather in Paris?"
+    include = ["message.output_text.logprobs"]
+    tools = [
+        {
+            "type": "function",
+            "name": "get_weather",
+            "description": "Get weather information for a specified location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city name (e.g., 'New York', 'London')",
+                    },
+                },
+            },
+        },
+    ]
+
+    # Create a response with function tool and include["message.output_text.logprobs"]
+    response = client_with_models.responses.create(
+        model=text_model_id,
+        input=input,
+        stream=False,
+        include=include,
+        tools=tools,
+    )
+
+    # Verify we got one function tool call and no logprobs
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].name == "get_weather"
+    assert response.output[0].status == "completed"
+    message_outputs = [output for output in response.output if output.type == "message"]
+    assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"