mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 09:42:35 +00:00
feat!: Implement include parameter specifically for adding logprobs in the output message (#4261)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 15s
Python Package Build Test / build (3.12) (push) Successful in 17s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
UI Tests / ui-tests (22) (push) Successful in 52s
Unit Tests / unit-tests (3.13) (push) Failing after 1m45s
Unit Tests / unit-tests (3.12) (push) Failing after 1m58s
Pre-commit / pre-commit (22) (push) Successful in 3m9s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m5s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 15s
Python Package Build Test / build (3.12) (push) Successful in 17s
Python Package Build Test / build (3.13) (push) Successful in 18s
Test External API and Providers / test-external (venv) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
UI Tests / ui-tests (22) (push) Successful in 52s
Unit Tests / unit-tests (3.13) (push) Failing after 1m45s
Unit Tests / unit-tests (3.12) (push) Failing after 1m58s
Pre-commit / pre-commit (22) (push) Successful in 3m9s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m5s
# Problem As an Application Developer, I want to use the include parameter with the value message.output_text.logprobs, so that I can receive log probabilities for output tokens to assess the model's confidence in its response. # What does this PR do? - Updates the include parameter in various resource definitions - Updates the inline provider to return logprobs when "message.output_text.logprobs" is passed in the include parameter - Converts the logprobs returned by the inference provider from chat completion format to responses format Closes #[4260](https://github.com/llamastack/llama-stack/issues/4260) ## Test Plan - Created a script to explore OpenAI behavior: https://github.com/s-akhtar-baig/llama-stack-examples/blob/main/responses/src/include.py - Added integration tests and new recordings --------- Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
76e47d811a
commit
805abf573f
26 changed files with 13524 additions and 161 deletions
|
|
@ -12,6 +12,22 @@ from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_
|
|||
from .streaming_assertions import StreamingValidator
|
||||
|
||||
|
||||
def provider_from_model(client_with_models, text_model_id):
|
||||
models = {m.id: m for m in client_with_models.models.list()}
|
||||
models.update(
|
||||
{m.custom_metadata["provider_resource_id"]: m for m in client_with_models.models.list() if m.custom_metadata}
|
||||
)
|
||||
provider_id = models[text_model_id].custom_metadata["provider_id"]
|
||||
providers = {p.provider_id: p for p in client_with_models.providers.list()}
|
||||
return providers[provider_id]
|
||||
|
||||
|
||||
def skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id):
|
||||
provider_type = provider_from_model(client_with_models, text_model_id).provider_type
|
||||
if provider_type in ("remote::ollama",):
|
||||
pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", basic_test_cases)
|
||||
def test_response_non_streaming_basic(responses_client, text_model_id, case):
|
||||
response = responses_client.responses.create(
|
||||
|
|
@ -206,3 +222,153 @@ def test_response_non_streaming_multi_turn_image(responses_client, text_model_id
|
|||
previous_response_id = response.id
|
||||
output_text = response.output_text.lower()
|
||||
assert turn_expected.lower() in output_text
|
||||
|
||||
|
||||
def test_include_logprobs_non_streaming(client_with_models, text_model_id):
|
||||
"""Test logprobs inclusion in responses with the include parameter."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Which planet do humans live on?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
|
||||
# Create a response without include["message.output_text.logprobs"]
|
||||
response_w_o_logprobs = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify we got one output message and no logprobs
|
||||
assert len(response_w_o_logprobs.output) == 1
|
||||
message_outputs = [output for output in response_w_o_logprobs.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is None, "Expected no logprobs in the returned response"
|
||||
|
||||
# Create a response with include["message.output_text.logprobs"]
|
||||
response_with_logprobs = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
)
|
||||
|
||||
# Verify we got one output message and output message has logprobs
|
||||
assert len(response_with_logprobs.output) == 1
|
||||
message_outputs = [output for output in response_with_logprobs.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
"Expected logprobs in the returned response, but none were returned"
|
||||
)
|
||||
|
||||
|
||||
def test_include_logprobs_streaming(client_with_models, text_model_id):
|
||||
"""Test logprobs inclusion in responses with the include parameter."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Which planet do humans live on?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
|
||||
# Create a streaming response with include["message.output_text.logprobs"]
|
||||
stream = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=True,
|
||||
include=include,
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.type == "response.completed":
|
||||
message_outputs = [output for output in chunk.response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type == "response.output_item.done":
|
||||
content = chunk.item.content
|
||||
assert len(content) == 1, f"Expected one content object, got {len(content)}"
|
||||
assert content[0].logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type in ["response.output_text.delta", "response.output_text.done"]:
|
||||
assert chunk.logprobs is not None, (
|
||||
f"Expected logprobs in the returned chunk ({chunk.type=}), but none were returned"
|
||||
)
|
||||
elif chunk.type == "response.content_part.done":
|
||||
assert chunk.part.logprobs is None, f"Expected no logprobs in the returned chunk ({chunk.type=})"
|
||||
|
||||
|
||||
def test_include_logprobs_with_web_search(client_with_models, text_model_id):
|
||||
"""Test include logprobs with built-in tool."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "Search for a positive news story from today."
|
||||
include = ["message.output_text.logprobs"]
|
||||
tools = [
|
||||
{
|
||||
"type": "web_search",
|
||||
}
|
||||
]
|
||||
|
||||
# Create a response with built-in tool and include["message.output_text.logprobs"]
|
||||
response = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we got one built-in tool call and output message has logprobs
|
||||
assert len(response.output) >= 2
|
||||
assert response.output[0].type == "web_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 1, f"Expected one message output, got {len(message_outputs)}"
|
||||
assert message_outputs[0].content[0].logprobs is not None, (
|
||||
"Expected logprobs in the returned response, but none were returned"
|
||||
)
|
||||
|
||||
|
||||
def test_include_logprobs_with_function_tools(client_with_models, text_model_id):
|
||||
"""Test include logprobs with function tools."""
|
||||
|
||||
skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id)
|
||||
|
||||
input = "What is the weather in Paris?"
|
||||
include = ["message.output_text.logprobs"]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a specified location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name (e.g., 'New York', 'London')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Create a response with function tool and include["message.output_text.logprobs"]
|
||||
response = client_with_models.responses.create(
|
||||
model=text_model_id,
|
||||
input=input,
|
||||
stream=False,
|
||||
include=include,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# Verify we got one function tool call and no logprobs
|
||||
assert len(response.output) == 1
|
||||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].name == "get_weather"
|
||||
assert response.output[0].status == "completed"
|
||||
message_outputs = [output for output in response.output if output.type == "message"]
|
||||
assert len(message_outputs) == 0, f"Expected no message output, got {len(message_outputs)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue