feat!: Implement include parameter specifically for adding logprobs in the output message (#4261)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 15s
Python Package Build Test / build (3.12) (push) Successful in 17s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
UI Tests / ui-tests (22) (push) Successful in 52s
Unit Tests / unit-tests (3.13) (push) Failing after 1m45s
Unit Tests / unit-tests (3.12) (push) Failing after 1m58s
Pre-commit / pre-commit (22) (push) Successful in 3m9s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m5s
# Problem

As an Application Developer, I want to use the `include` parameter with the value `message.output_text.logprobs`, so that I can receive log probabilities for output tokens to assess the model's confidence in its response.

# What does this PR do?

- Updates the `include` parameter in various resource definitions
- Updates the inline provider to return logprobs when `"message.output_text.logprobs"` is passed in the `include` parameter
- Converts the logprobs returned by the inference provider from chat completion format to responses format

Closes #[4260](https://github.com/llamastack/llama-stack/issues/4260)

## Test Plan

- Created a script to explore OpenAI behavior: https://github.com/s-akhtar-baig/llama-stack-examples/blob/main/responses/src/include.py
- Added integration tests and new recordings

---------

Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
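For orientation, a request that exercises the new behavior might look like the sketch below. The server URL, API key, and model name are assumptions, not values from this PR; the `include` value is the one the PR implements.

```python
# Minimal sketch: request per-token logprobs via the Responses API.
# Assumes a Llama Stack server at http://localhost:8321 and an illustrative
# model name; only the `include` value comes from this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="What is the capital of France?",
    include=["message.output_text.logprobs"],
)

# With this PR, each output_text content part carries a `logprobs` list.
for item in response.output:
    if item.type == "message":
        for part in item.content:
            if part.type == "output_text" and part.logprobs:
                for lp in part.logprobs:
                    print(f"{lp.token!r}: {lp.logprob:.4f}")
```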
parent 76e47d811a
commit 805abf573f

26 changed files with 13524 additions and 161 deletions
```diff
@@ -25,7 +25,7 @@ __version__ = "0.4.0.dev0"
 from . import common  # noqa: F401
 
 # Import all public API symbols
-from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
+from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec, ResponseItemInclude
 from .batches import (
     Batches,
     BatchObject,
@@ -798,6 +798,7 @@ __all__ = [
     "ResponseFormatType",
     "ResponseGuardrail",
     "ResponseGuardrailSpec",
+    "ResponseItemInclude",
     "RouteInfo",
     "RoutingTable",
     "RowsDataSource",
```
```diff
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from collections.abc import AsyncIterator
+from enum import StrEnum
 from typing import Annotated, Protocol, runtime_checkable
 
 from pydantic import BaseModel
@@ -40,6 +41,20 @@ class ResponseGuardrailSpec(BaseModel):
 ResponseGuardrail = str | ResponseGuardrailSpec
 
 
+class ResponseItemInclude(StrEnum):
+    """
+    Specify additional output data to include in the model response.
+    """
+
+    web_search_call_action_sources = "web_search_call.action.sources"
+    code_interpreter_call_outputs = "code_interpreter_call.outputs"
+    computer_call_output_output_image_url = "computer_call_output.output.image_url"
+    file_search_call_results = "file_search_call.results"
+    message_input_image_image_url = "message.input_image.image_url"
+    message_output_text_logprobs = "message.output_text.logprobs"
+    reasoning_encrypted_content = "reasoning.encrypted_content"
+
+
 @runtime_checkable
 class Agents(Protocol):
     """Agents
@@ -80,7 +95,7 @@ class Agents(Protocol):
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
-        include: list[str] | None = None,
+        include: list[ResponseItemInclude] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
         guardrails: Annotated[
             list[ResponseGuardrail] | None,
```
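Because `ResponseItemInclude` is a `StrEnum`, its members compare equal to their wire-format strings, so existing callers that pass plain strings keep working. A small sketch (import path inferred from the diff):

```python
from llama_stack_api.agents import ResponseItemInclude

# StrEnum members are themselves strings, so both spellings are equivalent.
assert ResponseItemInclude.message_output_text_logprobs == "message.output_text.logprobs"

# Unknown values fail fast, which is the point of typing `include` this way.
ResponseItemInclude("message.output_text.logprobs")  # ok
# ResponseItemInclude("not.a.real.include")  # would raise ValueError
```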
```diff
@@ -582,7 +582,7 @@ class OpenAITokenLogProb(BaseModel):
     token: str
     bytes: list[int] | None = None
     logprob: float
-    top_logprobs: list[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb] | None = None
 
 
 @json_schema_type
```
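Making `top_logprobs` optional lets a provider emit token logprobs without alternatives, which the responses-format conversion needs. A sketch (import path taken from the diff below):

```python
from llama_stack_api.inference import OpenAITokenLogProb

# After this change, a token logprob no longer has to carry alternatives.
lp = OpenAITokenLogProb(token="Paris", logprob=-0.0042, bytes=[80, 97, 114, 105, 115])
assert lp.top_logprobs is None
```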
```diff
@@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import TypedDict
 
+from llama_stack_api.inference import OpenAITokenLogProb
 from llama_stack_api.schema_utils import json_schema_type, register_schema
 from llama_stack_api.vector_io import SearchRankingOptions as FileSearchRankingOptions
 
```
```diff
@@ -173,6 +174,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
     text: str
     type: Literal["output_text"] = "output_text"
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
+    logprobs: list[OpenAITokenLogProb] | None = None
 
 
 @json_schema_type
```
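The PR summary mentions converting provider logprobs from chat-completion format to responses format. Roughly, that conversion could look like the sketch below; `to_response_logprobs` is a hypothetical name, and the input shape follows the OpenAI chat-completion logprobs schema, where `choice.logprobs.content` holds per-token entries:

```python
from llama_stack_api.inference import OpenAITokenLogProb


def to_response_logprobs(chat_logprobs) -> list[OpenAITokenLogProb] | None:
    # Hypothetical helper: `chat_logprobs` is `choice.logprobs` from a chat
    # completion, whose `.content` lists per-token entries.
    if chat_logprobs is None or chat_logprobs.content is None:
        return None
    return [
        OpenAITokenLogProb(
            token=entry.token,
            logprob=entry.logprob,
            bytes=entry.bytes,
            top_logprobs=None,  # now optional, per the inference hunk above
        )
        for entry in chat_logprobs.content
    ]
```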
```diff
@@ -746,6 +748,7 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
     :param content_index: Index position within the text content
     :param delta: Incremental text content being added
     :param item_id: Unique identifier of the output item being updated
+    :param logprobs: (Optional) Token log probability details
     :param output_index: Index position of the item in the output list
     :param sequence_number: Sequential number for ordering streaming events
     :param type: Event type identifier, always "response.output_text.delta"
@@ -754,6 +757,7 @@
     content_index: int
     delta: str
     item_id: str
+    logprobs: list[OpenAITokenLogProb] | None = None
     output_index: int
     sequence_number: int
     type: Literal["response.output_text.delta"] = "response.output_text.delta"
```
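On the streaming path, the new `logprobs` field rides on `response.output_text.delta` events. A consumption sketch (server URL and model name are assumptions, as above):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

with client.responses.stream(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Say hello.",
    include=["message.output_text.logprobs"],
) as stream:
    for event in stream:
        # `logprobs` is the field added by this PR on output_text delta events.
        if event.type == "response.output_text.delta" and getattr(event, "logprobs", None):
            for lp in event.logprobs:
                print(lp.token, lp.logprob)
```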
```diff
@@ -944,7 +948,7 @@ class OpenAIResponseContentPartOutputText(BaseModel):
     type: Literal["output_text"] = "output_text"
     text: str
     annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
-    logprobs: list[dict[str, Any]] | None = None
+    logprobs: list[OpenAITokenLogProb] | None = None
 
 
 @json_schema_type
```