diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 0bb524f5c..8a662e6db 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -126,6 +128,7 @@
             # Text is the default response format for chat completion so don't need to pass it
             # (some providers don't support non-empty response_format when tools are present)
             response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+            logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
             completion_result = await self.inference_api.openai_chat_completion(
                 model=self.ctx.model,
                 messages=messages,
@@ -160,7 +163,7 @@
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
 
             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -211,6 +214,8 @@
 
             for choice in current_response.choices:
                 next_turn_messages.append(choice.message)
+                logger.debug(f"Choice message content: {choice.message.content}")
+                logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
 
                 if choice.message.tool_calls and self.ctx.response_tools:
                     for tool_call in choice.message.tool_calls:
@@ -470,6 +475,8 @@
             tool_call_log = result.final_output_message
             tool_response_message = result.final_input_message
             self.sequence_number = result.sequence_number
+            if result.citation_files:
+                self.citation_files.update(result.citation_files)
 
             if tool_call_log:
                 output_messages.append(tool_call_log)
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
index b028c018b..b33b47454 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
@@ -94,7 +94,10 @@ class ToolExecutor:
 
         # Yield the final result
         yield ToolExecutionResult(
-            sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+            sequence_number=sequence_number,
+            final_output_message=output_message,
+            final_input_message=input_message,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )
 
     async def _execute_knowledge_search_via_vector_store(
@@ -129,8 +132,6 @@ class ToolExecutor:
         for results in all_results:
            search_results.extend(results)
 
-        # Convert search results to tool result format matching memory.py
-        # Format the results as interleaved content similar to memory.py
         content_items = []
         content_items.append(
             TextContentItem(
@@ -138,27 +139,58 @@ class ToolExecutor:
             )
         )
 
+        unique_files = set()
         for i, result_item in enumerate(search_results):
             chunk_text = result_item.content[0].text if result_item.content else ""
-            metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+            # Get file_id from attributes if result_item.file_id is empty
+            file_id = result_item.file_id or (
+                result_item.attributes.get("document_id") if result_item.attributes else None
+            )
+            metadata_text = f"document_id: {file_id}, score: {result_item.score}"
             if result_item.attributes:
                 metadata_text += f", attributes: {result_item.attributes}"
-            text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+
+            text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
             content_items.append(TextContentItem(text=text_content))
+            unique_files.add(file_id)
 
         content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+
+        citation_instruction = ""
+        if unique_files:
+            citation_instruction = (
+                " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+                "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )
+
         content_items.append(
             TextContentItem(
-                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
             )
         )
 
+        # handling missing attributes for old versions
+        citation_files = {}
+        for result in search_results:
+            file_id = result.file_id
+            if not file_id and result.attributes:
+                file_id = result.attributes.get("document_id")
+
+            filename = result.filename
+            if not filename and result.attributes:
+                filename = result.attributes.get("filename")
+            if not filename:
+                filename = "unknown"
+
+            citation_files[file_id] = filename
+
         return ToolInvocationResult(
             content=content_items,
             metadata={
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
+                "citation_files": citation_files,
             },
         )
 
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index d3b5a16bd..fd5f44242 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
+    citation_files: dict[str, str] | None = None
 
 
 @dataclass
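To make the prompt-side convention above concrete, here is a minimal standalone sketch of how the knowledge-search formatting and the citation_files mapping behave. The search results, file IDs, and filenames below are made up, and plain dicts stand in for the real VectorStoreSearchResponse objects:

```python
# Illustrative sketch only: mirrors how _execute_knowledge_search_via_vector_store
# labels each chunk with a "(cite as <|file-id|>)" marker and builds the
# file_id -> filename map, including the fallback to attributes for older servers.
search_results = [
    {"file_id": "file-abc123", "filename": "geography.pdf", "score": 0.91,
     "text": "Paris is the capital of France."},
    {"file_id": None, "attributes": {"document_id": "file-def456", "filename": "history.txt"},
     "score": 0.77, "text": "The Eiffel Tower opened in 1889."},
]

citation_files = {}
for i, r in enumerate(search_results):
    # Fall back to attributes when file_id / filename are missing.
    file_id = r.get("file_id") or r.get("attributes", {}).get("document_id")
    filename = r.get("filename") or r.get("attributes", {}).get("filename") or "unknown"
    citation_files[file_id] = filename
    print(f"[{i + 1}] document_id: {file_id}, score: {r['score']} (cite as <|{file_id}|>)\n{r['text']}\n")

print(citation_files)  # {'file-abc123': 'geography.pdf', 'file-def456': 'history.txt'}
```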
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 310a88298..5b013b9c4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import re
 import uuid
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
 )
 
 
-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
     if isinstance(choice.message.content, str):
@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )
 
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
         status="completed",
         role="assistant",
     )
@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)
 
 
+def _extract_citations_from_text(
+    text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+    """Extract citation markers from text and create annotations
+
+    Args:
+        text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
+        citation_files: Dictionary mapping file_id to filename
+
+    Returns:
+        Tuple of (annotations_list, clean_text_without_markers)
+    """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+    annotations = []
+    parts = []
+    total_len = 0
+    last_end = 0
+
+    for m in file_id_regex.finditer(text):
+        # segment before the marker
+        prefix = text[last_end : m.start()]
+
+        # drop one space if it exists (since marker is at sentence end)
+        if prefix.endswith(" "):
+            prefix = prefix[:-1]
+
+        parts.append(prefix)
+        total_len += len(prefix)
+
+        fid = m.group(1)
+        if fid in citation_files:
+            annotations.append(
+                OpenAIResponseAnnotationFileCitation(
+                    file_id=fid,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
+                )
+            )
+
+        last_end = m.end()
+
+    parts.append(text[last_end:])
+    cleaned_text = "".join(parts)
+    return annotations, cleaned_text
+
+
 def is_function_tool_call(
     tool_call: OpenAIChatCompletionToolCall,
     tools: list[OpenAIResponseInputTool],
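As a quick usage sketch (not part of the patch), this is roughly what the new helper returns for a one-sentence reply; the file ID and filename are hypothetical:

```python
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
)

text = "The capital of France is Paris <|file-abc123|>."
annotations, clean_text = _extract_citations_from_text(text, {"file-abc123": "geography.pdf"})

print(clean_text)              # "The capital of France is Paris."
print(annotations[0].file_id)  # "file-abc123"
print(annotations[0].index)    # 30 -- the annotation lands on the trailing "."
```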
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c8499a9b8..aac86a056 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -331,5 +331,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
 
         return ToolInvocationResult(
             content=result.content or [],
-            metadata=result.metadata,
+            metadata={
+                **(result.metadata or {}),
+                "citation_files": getattr(result, "citation_files", None),
+            },
         )
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index 0d0aa25a4..97079c3b3 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -587,7 +587,7 @@ class OpenAIVectorStoreMixin(ABC):
                 content = self._chunk_to_vector_store_content(chunk)
 
                 response_data_item = VectorStoreSearchResponse(
-                    file_id=chunk.metadata.get("file_id", ""),
+                    file_id=chunk.metadata.get("document_id", ""),
                     filename=chunk.metadata.get("filename", ""),
                     score=score,
                     attributes=chunk.metadata,
@@ -746,12 +746,15 @@ class OpenAIVectorStoreMixin(ABC):
 
             content = content_from_data_and_mime_type(content_response.body, mime_type)
 
+            chunk_attributes = attributes.copy()
+            chunk_attributes["filename"] = file_response.filename
+
             chunks = make_overlapped_chunks(
                 file_id,
                 content,
                 max_chunk_size_tokens,
                 chunk_overlap_tokens,
-                attributes,
+                chunk_attributes,
             )
             if not chunks:
                 vector_store_file_object.status = "failed"
diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
index 187540f82..2698b88c8 100644
--- a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
@@ -8,6 +8,7 @@
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
     OpenAIUserMessageParam,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+    _extract_citations_from_text,
     convert_chat_choice_to_response_message,
     convert_response_content_to_chat_content,
     convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
         result = is_function_tool_call(tool_call, tools)
 
         assert result is False
+
+
+class TestExtractCitationsFromText:
+    def test_extract_citations_and_annotations(self):
+        text = "Start [not-a-file]. New source <|file-abc123|>. "
+        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
+        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
+
+        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
+
+        expected_annotations = [
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
+            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
+        ]
+        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
+
+        assert cleaned_text == expected_clean_text
+        assert annotations == expected_annotations
+        # OpenAI cites at the end of the sentence
+        assert cleaned_text[expected_annotations[0].index] == "."
+        assert cleaned_text[expected_annotations[1].index] == "?"
+        assert cleaned_text[expected_annotations[2].index] == "!"
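Assuming the conversion path above, the annotated output message a Responses API client receives should look roughly like the hand-built example below; the file ID, filename, and text are made up, and the object is constructed directly only for illustration:

```python
import uuid

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageContentOutputText,
)

# The <|file-abc123|> marker has been stripped from the text; the annotation's
# index (30) points at the sentence-ending "." that followed it.
message = OpenAIResponseMessage(
    id=f"msg_{uuid.uuid4()}",
    role="assistant",
    status="completed",
    content=[
        OpenAIResponseOutputMessageContentOutputText(
            text="The capital of France is Paris.",
            annotations=[
                OpenAIResponseAnnotationFileCitation(
                    file_id="file-abc123",     # hypothetical file ID
                    filename="geography.pdf",  # hypothetical filename
                    index=30,
                )
            ],
        )
    ],
)
print(message.model_dump_json(indent=2))
```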