mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 20:12:33 +00:00)

minor cleanup and prompt update

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

commit eb8acd24b0 (parent 38205492f8)

6 changed files with 30 additions and 47 deletions
@@ -97,8 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
-        # file mapping for annotations
-        self.file_mapping: dict[str, str] = {}
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -163,7 +163,7 @@ class StreamingResponseOrchestrator:
         # Handle choices with no tool calls
         for choice in current_response.choices:
             if not (choice.message.tool_calls and self.ctx.response_tools):
-                output_messages.append(await convert_chat_choice_to_response_message(choice, self.file_mapping))
+                output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))

         # Execute tool calls and coordinate results
         async for stream_event in self._coordinate_tool_execution(
@@ -475,8 +475,8 @@ class StreamingResponseOrchestrator:
         tool_call_log = result.final_output_message
         tool_response_message = result.final_input_message
         self.sequence_number = result.sequence_number
-        if result.file_mapping:
-            self.file_mapping.update(result.file_mapping)
+        if result.citation_files:
+            self.citation_files.update(result.citation_files)

         if tool_call_log:
             output_messages.append(tool_call_log)
@@ -562,9 +562,7 @@ class StreamingResponseOrchestrator:
             tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
             if not tool:
                 raise ValueError(f"Tool {tool_name} not found")
-            openai_tool = make_openai_tool(tool_name, tool)
-            logger.debug(f"Adding file_search tool as knowledge_search: {openai_tool}")
-            self.ctx.chat_tools.append(openai_tool)
+            self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
         elif input_tool.type == "mcp":
             async for stream_event in self._process_mcp_tool(input_tool, output_messages):
                 yield stream_event
@@ -97,7 +97,7 @@ class ToolExecutor:
             sequence_number=sequence_number,
             final_output_message=output_message,
             final_input_message=input_message,
-            file_mapping=result.metadata.get("_annotation_file_mapping") if result and result.metadata else None,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )

     async def _execute_knowledge_search_via_vector_store(
@@ -158,8 +158,10 @@ class ToolExecutor:

         citation_instruction = ""
         if unique_files:
-            citation_instruction = " Cite sources at the end of each sentence, after punctuation, using `<|file-id|>` (e.g. .<|file-Cn3MSNn72ENTiiq11Qda4A|>)."
-            citation_instruction += " Use only the file IDs provided (do not invent new ones)."
+            citation_instruction = " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.')."
+            citation_instruction += (
+                " Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )

         content_items.append(
             TextContentItem(
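The prompt update flips where the model places the citation marker relative to sentence-ending punctuation. A quick illustration of the two styles (the sentence below is made up; only the marker format comes from the diff):

# old prompt: marker after the punctuation
"Retrieval improves grounding.<|file-Cn3MSNn72ENTiiq11Qda4A|>"

# new prompt: marker before the punctuation, with no extra punctuation added,
# so stripping the marker (and the single space before it) restores a normal sentence
"Retrieval improves grounding <|file-Cn3MSNn72ENTiiq11Qda4A|>."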
@@ -168,7 +170,7 @@ class ToolExecutor:
         )

         # handling missing attributes for old versions
-        annotation_file_mapping = {
+        citation_files = {
             (r.file_id or (r.attributes.get("document_id") if r.attributes else None)): r.filename
             or (r.attributes.get("filename") if r.attributes else None)
             or "unknown"
@@ -181,7 +183,7 @@ class ToolExecutor:
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
-                "_annotation_file_mapping": annotation_file_mapping,
+                "citation_files": citation_files,
             },
         )
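The renamed comprehension resolves an ID and a filename for each search result, falling back to attributes for entries written before file_id and filename existed. A minimal, self-contained sketch of that fallback chain (the SimpleNamespace stand-ins are assumptions, not the real result type):

from types import SimpleNamespace

# a current result (has file_id/filename) and a legacy one (attributes only)
new = SimpleNamespace(file_id="file-abc", filename="paper.pdf", attributes=None)
old = SimpleNamespace(file_id=None, filename=None,
                      attributes={"document_id": "doc-1", "filename": "legacy.txt"})

citation_files = {
    (r.file_id or (r.attributes.get("document_id") if r.attributes else None)): r.filename
    or (r.attributes.get("filename") if r.attributes else None)
    or "unknown"
    for r in [new, old]
}
print(citation_files)  # {'file-abc': 'paper.pdf', 'doc-1': 'legacy.txt'}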
@@ -27,7 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
-    file_mapping: dict[str, str] | None = None
+    citation_files: dict[str, str] | None = None


 @dataclass
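Putting the pieces together: the executor attaches the mapping to its result, and the orchestrator merges it into its own dict (see the -475,8 hunk above). A rough sketch, assuming ToolExecutionResult has no required fields beyond those visible in this hunk:

result = ToolExecutionResult(
    sequence_number=3,
    citation_files={"file-Cn3MSNn72ENTiiq11Qda4A": "ml_basics.pdf"},
)

orchestrator_citation_files: dict[str, str] = {}  # stands in for self.citation_files
if result.citation_files:
    orchestrator_citation_files.update(result.citation_files)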
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import re
-import unicodedata
 import uuid

 from llama_stack.apis.agents.openai_responses import (
@@ -49,7 +48,7 @@ from llama_stack.apis.inference import (


 async def convert_chat_choice_to_response_message(
-    choice: OpenAIChoice, file_mapping: dict[str, str] | None = None
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
 ) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
@@ -62,7 +61,7 @@ async def convert_chat_choice_to_response_message(
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )

-    annotations, clean_text = _extract_citations_from_text(output_content, file_mapping or {})
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})

     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
@@ -207,65 +206,50 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)


-def _is_punct(ch: str) -> bool:
-    return bool(ch) and unicodedata.category(ch).startswith("P")
-
-
-def _is_word_char(ch: str) -> bool:
-    return bool(ch) and (ch.isalnum() or ch == "_")
-
-
 def _extract_citations_from_text(
-    text: str, file_mapping: dict[str, str]
+    text: str, citation_files: dict[str, str]
 ) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
     """Extract citation markers from text and create annotations

     Args:
         text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
-        file_mapping: Dictionary mapping file_id to filename
+        citation_files: Dictionary mapping file_id to filename

     Returns:
         Tuple of (annotations_list, clean_text_without_markers)
     """
     file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")

-    annotations: list[OpenAIResponseAnnotationFileCitation] = []
-    parts: list[str] = []
+    annotations = []
+    parts = []
     total_len = 0
     last_end = 0

     for m in file_id_regex.finditer(text):
         # segment before the marker
         prefix = text[last_end : m.start()]
-        # remove trailing space
-
+
+        # drop one space if it exists (since marker is at sentence end)
         if prefix.endswith(" "):
             prefix = prefix[:-1]

-        # skip all spaces after the marker
-        j = m.end()
-        while j < len(text) and text[j].isspace():
-            j += 1
-
-        # append normalized prefix
         parts.append(prefix)
         total_len += len(prefix)

-        # point to the next visible character
-        fid = m.group("file_id")
-        if fid in file_mapping:
+        fid = m.group(1)
+        if fid in citation_files:
             annotations.append(
                 OpenAIResponseAnnotationFileCitation(
                     file_id=fid,
-                    filename=file_mapping[fid],
-                    index=total_len,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
                 )
             )

-        last_end = j
+        last_end = m.end()

     # append remaining part
     parts.append(text[last_end:])
     cleaned_text = "".join(parts)
     annotations.sort(key=lambda a: a.index)
     return annotations, cleaned_text
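A worked example of the rewritten extraction on a single cited sentence (the filename and sentence are made up, and this assumes the new code path shown above):

citation_files = {"file-Cn3MSNn72ENTiiq11Qda4A": "ml_basics.pdf"}
text = "Machine learning is a subset of AI <|file-Cn3MSNn72ENTiiq11Qda4A|>."

annotations, clean_text = _extract_citations_from_text(text, citation_files)

# clean_text == "Machine learning is a subset of AI."
# annotations[0].file_id == "file-Cn3MSNn72ENTiiq11Qda4A"
# annotations[0].index == 34, and clean_text[34] == "."  (the index lands on the punctuation)

The marker, plus the one space before it, is removed. Because last_end is now m.end() rather than the old skip-past-whitespace cursor, the punctuation that follows the marker is preserved and the annotation index points directly at it.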
@@ -333,6 +333,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=result.content or [],
             metadata={
                 **(result.metadata or {}),
-                "_annotation_file_mapping": getattr(result, "annotation_file_mapping", None),
+                "citation_files": getattr(result, "citation_files", None),
             },
         )
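The getattr fallback keeps the metadata key present even when a RAG result predates the citation_files attribute; it simply carries None in that case. A tiny illustration (both result classes are stand-ins, not the real types):

class LegacyRAGResult:  # no citation_files attribute at all
    pass

class CurrentRAGResult:
    citation_files = {"file-abc": "notes.md"}

print(getattr(LegacyRAGResult(), "citation_files", None))   # None
print(getattr(CurrentRAGResult(), "citation_files", None))  # {'file-abc': 'notes.md'}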
@@ -361,8 +361,7 @@ class TestExtractCitationsFromText:

         assert cleaned_text == expected_clean_text
         assert annotations == expected_annotations
-        # OpenAI typically cites at the end of the sentence but we support the middle just in case,
-        # which makes the position the start of the next word.
+        # OpenAI cites at the end of the sentence
         assert cleaned_text[expected_annotations[0].index] == "."
         assert cleaned_text[expected_annotations[1].index] == "?"
         assert cleaned_text[expected_annotations[2].index] == "!"
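A hypothetical input that would satisfy the updated assertions (the file IDs are invented and would need matching entries in the citation_files dict passed to the extractor):

text = (
    "Python is dynamically typed <|file-A1|>. "
    "Is it also strongly typed <|file-B2|>? "
    "Yes it is <|file-C3|>!"
)
# After extraction, each annotation's index lands exactly on ".", "?" and "!" in the cleaned text.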