From eb8acd24b0bed0610615502baa67a28e5c9ab294 Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Mon, 6 Oct 2025 16:14:23 -0400
Subject: [PATCH] minor cleanup and prompt update

Signed-off-by: Francisco Javier Arceo
---
 .../meta_reference/responses/streaming.py     | 14 +++---
 .../meta_reference/responses/tool_executor.py | 12 ++---
 .../agents/meta_reference/responses/types.py  |  2 +-
 .../agents/meta_reference/responses/utils.py  | 44 ++++++-------------
 .../inline/tool_runtime/rag/memory.py         |  2 +-
 .../test_response_conversion_utils.py         |  3 +-
 6 files changed, 30 insertions(+), 47 deletions(-)
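
Reviewer note (this sits after the "---" separator, so git am ignores it, and it
is not part of the diff): the sketch below is a standalone re-implementation of
the contract of the patched _extract_citations_from_text helper, with a made-up
file id and filename. With the updated prompt the model emits citation markers
before sentence punctuation, so each annotation index lands on the punctuation
character of the cleaned text.

import re

def extract_citations(text: str, citation_files: dict[str, str]) -> tuple[list[dict], str]:
    """Strip `<|file-id|>` markers from text and record where each one was."""
    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
    annotations: list[dict] = []
    parts: list[str] = []
    total_len = 0
    last_end = 0
    for m in file_id_regex.finditer(text):
        prefix = text[last_end : m.start()]
        if prefix.endswith(" "):  # drop the single space the model puts before the marker
            prefix = prefix[:-1]
        parts.append(prefix)
        total_len += len(prefix)
        fid = m.group("file_id")
        if fid in citation_files:
            annotations.append({"file_id": fid, "filename": citation_files[fid], "index": total_len})
        last_end = m.end()
    parts.append(text[last_end:])
    return annotations, "".join(parts)

annotations, clean = extract_citations(
    "Machine learning improves search <|file-abc123|>.",
    {"file-abc123": "ml_paper.txt"},  # hypothetical citation_files mapping
)
assert clean == "Machine learning improves search."
assert clean[annotations[0]["index"]] == "."  # index points at the punctuation
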
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 649e38b51..8a662e6db 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -97,8 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
-        # file mapping for annotations
-        self.file_mapping: dict[str, str] = {}
+        # file_id -> filename mapping for citation annotations
+        self.citation_files: dict[str, str] = {}

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -163,7 +163,7 @@ class StreamingResponseOrchestrator:
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.file_mapping))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))

             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -475,8 +475,8 @@ class StreamingResponseOrchestrator:
                 tool_call_log = result.final_output_message
                 tool_response_message = result.final_input_message
                 self.sequence_number = result.sequence_number
-                if result.file_mapping:
-                    self.file_mapping.update(result.file_mapping)
+                if result.citation_files:
+                    self.citation_files.update(result.citation_files)

             if tool_call_log:
                 output_messages.append(tool_call_log)
@@ -562,9 +562,7 @@ class StreamingResponseOrchestrator:
                 tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
-                openai_tool = make_openai_tool(tool_name, tool)
-                logger.debug(f"Adding file_search tool as knowledge_search: {openai_tool}")
-                self.ctx.chat_tools.append(openai_tool)
+                self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 async for stream_event in self._process_mcp_tool(input_tool, output_messages):
                     yield stream_event
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
index 4b8c703ad..21a6c31d4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
@@ -97,7 +97,7 @@ class ToolExecutor:
             sequence_number=sequence_number,
             final_output_message=output_message,
             final_input_message=input_message,
-            file_mapping=result.metadata.get("_annotation_file_mapping") if result and result.metadata else None,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )

     async def _execute_knowledge_search_via_vector_store(
@@ -158,8 +158,10 @@ class ToolExecutor:

         citation_instruction = ""
         if unique_files:
-            citation_instruction = " Cite sources at the end of each sentence, after punctuation, using `<|file-id|>` (e.g. .<|file-Cn3MSNn72ENTiiq11Qda4A|>)."
-            citation_instruction += " Use only the file IDs provided (do not invent new ones)."
+            citation_instruction = " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.')."
+            citation_instruction += (
+                " Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )

         content_items.append(
             TextContentItem(
@@ -168,7 +170,7 @@ class ToolExecutor:
             )
         )
         # handling missing attributes for old versions
-        annotation_file_mapping = {
+        citation_files = {
             (r.file_id or (r.attributes.get("document_id") if r.attributes else None)): r.filename
             or (r.attributes.get("filename") if r.attributes else None)
             or "unknown"
@@ -181,7 +183,7 @@ class ToolExecutor:
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
-                "_annotation_file_mapping": annotation_file_mapping,
+                "citation_files": citation_files,
             },
         )

diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index f730aa33a..fd5f44242 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -27,7 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
-    file_mapping: dict[str, str] | None = None
+    citation_files: dict[str, str] | None = None


 @dataclass
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index ec89d7999..5b013b9c4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import re
-import unicodedata
 import uuid

 from llama_stack.apis.agents.openai_responses import (
@@ -49,7 +48,7 @@ from llama_stack.apis.inference import (


 async def convert_chat_choice_to_response_message(
-    choice: OpenAIChoice, file_mapping: dict[str, str] | None = None
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
 ) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
@@ -62,7 +61,7 @@ async def convert_chat_choice_to_response_message(
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )

-    annotations, clean_text = _extract_citations_from_text(output_content, file_mapping or {})
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})

     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
@@ -207,65 +206,50 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)


-def _is_punct(ch: str) -> bool:
-    return bool(ch) and unicodedata.category(ch).startswith("P")
-
-
-def _is_word_char(ch: str) -> bool:
-    return bool(ch) and (ch.isalnum() or ch == "_")
-
-
 def _extract_citations_from_text(
-    text: str, file_mapping: dict[str, str]
+    text: str, citation_files: dict[str, str]
 ) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
     """Extract citation markers from text and create annotations

     Args:
         text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
-        file_mapping: Dictionary mapping file_id to filename
+        citation_files: Dictionary mapping file_id to filename

     Returns:
         Tuple of (annotations_list, clean_text_without_markers)
     """
     file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
-    annotations: list[OpenAIResponseAnnotationFileCitation] = []
-    parts: list[str] = []
+    annotations = []
+    parts = []
     total_len = 0
     last_end = 0

     for m in file_id_regex.finditer(text):
+        # segment before the marker
         prefix = text[last_end : m.start()]
-        # remove trailing space
+
+        # drop one space if it exists (since marker is at sentence end)
         if prefix.endswith(" "):
             prefix = prefix[:-1]

-        # skip all spaces after the marker
-        j = m.end()
-        while j < len(text) and text[j].isspace():
-            j += 1
-
-        # append normalized prefix
         parts.append(prefix)
         total_len += len(prefix)

-        # point to the next visible character
-        fid = m.group("file_id")
-        if fid in file_mapping:
+        fid = m.group(1)
+        if fid in citation_files:
             annotations.append(
                 OpenAIResponseAnnotationFileCitation(
                     file_id=fid,
-                    filename=file_mapping[fid],
-                    index=total_len,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
                 )
             )

-        last_end = j
+        last_end = m.end()

-    # append remaining part
     parts.append(text[last_end:])
     cleaned_text = "".join(parts)
-    annotations.sort(key=lambda a: a.index)

     return annotations, cleaned_text
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 5503e9ec3..aac86a056 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -333,6 +333,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=result.content or [],
             metadata={
                 **(result.metadata or {}),
-                "_annotation_file_mapping": getattr(result, "annotation_file_mapping", None),
+                "citation_files": getattr(result, "citation_files", None),
             },
         )
diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
index f0f48bb92..2698b88c8 100644
--- a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
@@ -361,8 +361,7 @@ class TestExtractCitationsFromText:
         assert cleaned_text == expected_clean_text
         assert annotations == expected_annotations

-        # OpenAI typically cites at the end of the sentence but we support the middle just in case,
-        # which makes the position the start of the next word.
+        # OpenAI cites at the end of the sentence, so the index lands on the punctuation
         assert cleaned_text[expected_annotations[0].index] == "."
         assert cleaned_text[expected_annotations[1].index] == "?"
         assert cleaned_text[expected_annotations[2].index] == "!"
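
-- 
Reviewer note, behind the signature separator so it stays out of the patch: a
pytest-style sketch of how the updated expectations could be exercised end to
end. The file ids and filenames are made up; the import path and the
_extract_citations_from_text signature come from the files touched in this diff.

from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
)


def test_annotation_index_lands_on_punctuation():
    citation_files = {"file-A1": "a.txt", "file-B2": "b.txt", "file-C3": "c.txt"}
    text = "First fact <|file-A1|>. Is it so <|file-B2|>? Yes <|file-C3|>!"

    annotations, clean = _extract_citations_from_text(text, citation_files)

    # markers and the single space before each of them are stripped
    assert clean == "First fact. Is it so? Yes!"
    # OpenAI cites at the end of the sentence, so each index points at punctuation
    assert [clean[a.index] for a in annotations] == [".", "?", "!"]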