feat: Enabling Annotations in Responses

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-05 23:51:41 -04:00
parent c21bb0e837
commit 38205492f8
7 changed files with 147 additions and 13 deletions

View file

@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# file mapping for annotations
self.file_mapping: dict[str, str] = {}
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
# Text is the default response format for chat completion so don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
output_messages.append(await convert_chat_choice_to_response_message(choice, self.file_mapping))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:
for choice in current_response.choices:
next_turn_messages.append(choice.message)
logger.debug(f"Choice message content: {choice.message.content}")
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
@ -470,6 +475,8 @@ class StreamingResponseOrchestrator:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if result.file_mapping:
self.file_mapping.update(result.file_mapping)
if tool_call_log:
output_messages.append(tool_call_log)
@ -555,7 +562,9 @@ class StreamingResponseOrchestrator:
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
openai_tool = make_openai_tool(tool_name, tool)
logger.debug(f"Adding file_search tool as knowledge_search: {openai_tool}")
self.ctx.chat_tools.append(openai_tool)
elif input_tool.type == "mcp":
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
yield stream_event

View file

@ -94,7 +94,10 @@ class ToolExecutor:
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
file_mapping=result.metadata.get("_annotation_file_mapping") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
@ -129,8 +132,6 @@ class ToolExecutor:
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
@ -138,27 +139,49 @@ class ToolExecutor:
)
)
unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
citation_instruction = ""
if unique_files:
citation_instruction = " Cite sources at the end of each sentence, after punctuation, using `<|file-id|>` (e.g. .<|file-Cn3MSNn72ENTiiq11Qda4A|>)."
citation_instruction += " Use only the file IDs provided (do not invent new ones)."
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
# handling missing attributes for old versions
annotation_file_mapping = {
(r.file_id or (r.attributes.get("document_id") if r.attributes else None)): r.filename
or (r.attributes.get("filename") if r.attributes else None)
or "unknown"
for r in search_results
}
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"_annotation_file_mapping": annotation_file_mapping,
},
)

View file

@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
file_mapping: dict[str, str] | None = None
@dataclass

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import unicodedata
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -45,7 +48,9 @@ from llama_stack.apis.inference import (
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice, file_mapping: dict[str, str] | None = None
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
@ -57,9 +62,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
annotations, clean_text = _extract_citations_from_text(output_content, file_mapping or {})
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
@ -200,6 +207,68 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)
def _is_punct(ch: str) -> bool:
return bool(ch) and unicodedata.category(ch).startswith("P")
def _is_word_char(ch: str) -> bool:
return bool(ch) and (ch.isalnum() or ch == "_")
def _extract_citations_from_text(
text: str, file_mapping: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
"""Extract citation markers from text and create annotations
Args:
text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
file_mapping: Dictionary mapping file_id to filename
Returns:
Tuple of (annotations_list, clean_text_without_markers)
"""
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
annotations: list[OpenAIResponseAnnotationFileCitation] = []
parts: list[str] = []
total_len = 0
last_end = 0
for m in file_id_regex.finditer(text):
prefix = text[last_end : m.start()]
# remove trailing space
if prefix.endswith(" "):
prefix = prefix[:-1]
# skip all spaces after the marker
j = m.end()
while j < len(text) and text[j].isspace():
j += 1
# append normalized prefix
parts.append(prefix)
total_len += len(prefix)
# point to the next visible character
fid = m.group("file_id")
if fid in file_mapping:
annotations.append(
OpenAIResponseAnnotationFileCitation(
file_id=fid,
filename=file_mapping[fid],
index=total_len,
)
)
last_end = j
# append remaining part
parts.append(text[last_end:])
cleaned_text = "".join(parts)
annotations.sort(key=lambda a: a.index)
return annotations, cleaned_text
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],

View file

@ -331,5 +331,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return ToolInvocationResult(
content=result.content or [],
metadata=result.metadata,
metadata={
**(result.metadata or {}),
"_annotation_file_mapping": getattr(result, "annotation_file_mapping", None),
},
)

View file

@ -457,7 +457,7 @@ class OpenAIVectorStoreMixin(ABC):
content = self._chunk_to_vector_store_content(chunk)
response_data_item = VectorStoreSearchResponse(
file_id=chunk.metadata.get("file_id", ""),
file_id=chunk.metadata.get("document_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
@ -608,12 +608,15 @@ class OpenAIVectorStoreMixin(ABC):
content = content_from_data_and_mime_type(content_response.body, mime_type)
chunk_attributes = attributes.copy()
chunk_attributes["filename"] = file_response.filename
chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
attributes,
chunk_attributes,
)
if not chunks:

View file

@ -8,6 +8,7 @@
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
_extract_citations_from_text,
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
@ -340,3 +342,27 @@ class TestIsFunctionToolCall:
result = is_function_tool_call(tool_call, tools)
assert result is False
class TestExtractCitationsFromText:
def test_extract_citations_and_annotations(self):
text = "Start [not-a-file]. New source <|file-abc123|>. "
text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
expected_annotations = [
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
]
expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
assert cleaned_text == expected_clean_text
assert annotations == expected_annotations
# OpenAI typically cites at the end of the sentence but we support the middle just in case,
# which makes the position the start of the next word.
assert cleaned_text[expected_annotations[0].index] == "."
assert cleaned_text[expected_annotations[1].index] == "?"
assert cleaned_text[expected_annotations[2].index] == "!"