Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-15 00:52:38 +00:00)

commit eb8acd24b0 (parent 38205492f8)

    minor cleanup and prompt update

    Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

6 changed files with 30 additions and 47 deletions
@@ -97,8 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
-        # file mapping for annotations
-        self.file_mapping: dict[str, str] = {}
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -163,7 +163,7 @@ class StreamingResponseOrchestrator:
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.file_mapping))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
 
             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -475,8 +475,8 @@ class StreamingResponseOrchestrator:
             tool_call_log = result.final_output_message
             tool_response_message = result.final_input_message
             self.sequence_number = result.sequence_number
-            if result.file_mapping:
-                self.file_mapping.update(result.file_mapping)
+            if result.citation_files:
+                self.citation_files.update(result.citation_files)
 
             if tool_call_log:
                 output_messages.append(tool_call_log)
@@ -562,9 +562,7 @@ class StreamingResponseOrchestrator:
                 tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
-                openai_tool = make_openai_tool(tool_name, tool)
-                logger.debug(f"Adding file_search tool as knowledge_search: {openai_tool}")
-                self.ctx.chat_tools.append(openai_tool)
+                self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 async for stream_event in self._process_mcp_tool(input_tool, output_messages):
                     yield stream_event
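Net effect of the orchestrator changes: a single citation_files dict (file_id -> filename) is accumulated across tool executions and handed to convert_chat_choice_to_response_message whenever a choice carries no tool calls. A minimal sketch of that accumulation, using a bare class and plain dicts in place of the real Llama Stack result objects (IDs and filenames below are hypothetical):

# Minimal sketch of the citation_files accumulation; plain dicts stand in
# for Llama Stack's ToolExecutionResult objects.
class OrchestratorSketch:
    def __init__(self) -> None:
        # mapping for annotations: file_id -> filename
        self.citation_files: dict[str, str] = {}

    def absorb(self, result_citation_files: dict[str, str] | None) -> None:
        # mirrors: if result.citation_files: self.citation_files.update(...)
        if result_citation_files:
            self.citation_files.update(result_citation_files)

orch = OrchestratorSketch()
orch.absorb({"file-abc123": "facts.md"})  # hypothetical file IDs / names
orch.absorb(None)                         # tools without search results add nothing
orch.absorb({"file-def456": "faq.md"})
print(orch.citation_files)  # {'file-abc123': 'facts.md', 'file-def456': 'faq.md'}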
@@ -97,7 +97,7 @@ class ToolExecutor:
             sequence_number=sequence_number,
             final_output_message=output_message,
             final_input_message=input_message,
-            file_mapping=result.metadata.get("_annotation_file_mapping") if result and result.metadata else None,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )
 
     async def _execute_knowledge_search_via_vector_store(
@@ -158,8 +158,10 @@ class ToolExecutor:
 
         citation_instruction = ""
        if unique_files:
-            citation_instruction = " Cite sources at the end of each sentence, after punctuation, using `<|file-id|>` (e.g. .<|file-Cn3MSNn72ENTiiq11Qda4A|>)."
-            citation_instruction += " Use only the file IDs provided (do not invent new ones)."
+            citation_instruction = " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.')."
+            citation_instruction += (
+                " Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )
 
         content_items.append(
             TextContentItem(
@@ -168,7 +170,7 @@ class ToolExecutor:
             )
 
         # handling missing attributes for old versions
-        annotation_file_mapping = {
+        citation_files = {
            (r.file_id or (r.attributes.get("document_id") if r.attributes else None)): r.filename
            or (r.attributes.get("filename") if r.attributes else None)
            or "unknown"
@@ -181,7 +183,7 @@ class ToolExecutor:
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
-                "_annotation_file_mapping": annotation_file_mapping,
+                "citation_files": citation_files,
             },
         )
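The prompt change flips the marker convention: markers now go before the sentence-ending punctuation rather than after it, which is what lets the extractor's annotation index land on the punctuation itself (see the utils hunk further below). A sketch of what the assembled instruction evaluates to, assuming unique_files is a non-empty set (the surrounding prompt assembly is not shown in the hunk):

# Sketch: reconstructing the new citation_instruction string from the hunk above.
unique_files = {"file-Cn3MSNn72ENTiiq11Qda4A"}  # assumed non-empty for the demo

citation_instruction = ""
if unique_files:
    citation_instruction = (
        " Cite sources immediately at the end of sentences before punctuation, using"
        " `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.')."
    )
    citation_instruction += (
        " Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
    )

print(citation_instruction)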
@@ -27,7 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
-    file_mapping: dict[str, str] | None = None
+    citation_files: dict[str, str] | None = None
 
 
 @dataclass
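The model change is the same rename at the type level. A self-contained sketch of the renamed field, with the real OpenAI message types simplified to str so it runs without Llama Stack imports:

from pydantic import BaseModel

# Simplified stand-in for ToolExecutionResult, for illustration only.
class ToolExecutionResultSketch(BaseModel):
    sequence_number: int
    final_output_message: str | None = None
    final_input_message: str | None = None
    citation_files: dict[str, str] | None = None  # renamed from file_mapping

r = ToolExecutionResultSketch(sequence_number=3, citation_files={"file-abc123": "facts.md"})
print(r.citation_files)  # {'file-abc123': 'facts.md'}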
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import re
-import unicodedata
 import uuid
 
 from llama_stack.apis.agents.openai_responses import (
@@ -49,7 +48,7 @@ from llama_stack.apis.inference import (
 
 
 async def convert_chat_choice_to_response_message(
-    choice: OpenAIChoice, file_mapping: dict[str, str] | None = None
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
 ) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
@@ -62,7 +61,7 @@ async def convert_chat_choice_to_response_message(
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )
 
-    annotations, clean_text = _extract_citations_from_text(output_content, file_mapping or {})
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
 
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
@@ -207,65 +206,50 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)
 
 
-def _is_punct(ch: str) -> bool:
-    return bool(ch) and unicodedata.category(ch).startswith("P")
-
-
-def _is_word_char(ch: str) -> bool:
-    return bool(ch) and (ch.isalnum() or ch == "_")
-
-
 def _extract_citations_from_text(
-    text: str, file_mapping: dict[str, str]
+    text: str, citation_files: dict[str, str]
 ) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
     """Extract citation markers from text and create annotations
 
     Args:
         text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
-        file_mapping: Dictionary mapping file_id to filename
+        citation_files: Dictionary mapping file_id to filename
 
     Returns:
         Tuple of (annotations_list, clean_text_without_markers)
     """
     file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
 
-    annotations: list[OpenAIResponseAnnotationFileCitation] = []
-    parts: list[str] = []
+    annotations = []
+    parts = []
     total_len = 0
     last_end = 0
 
     for m in file_id_regex.finditer(text):
+        # segment before the marker
         prefix = text[last_end : m.start()]
-        # remove trailing space
+        # drop one space if it exists (since marker is at sentence end)
         if prefix.endswith(" "):
             prefix = prefix[:-1]
 
-        # skip all spaces after the marker
-        j = m.end()
-        while j < len(text) and text[j].isspace():
-            j += 1
-
-        # append normalized prefix
         parts.append(prefix)
         total_len += len(prefix)
 
-        # point to the next visible character
-        fid = m.group("file_id")
-        if fid in file_mapping:
+        fid = m.group(1)
+        if fid in citation_files:
             annotations.append(
                 OpenAIResponseAnnotationFileCitation(
                     file_id=fid,
-                    filename=file_mapping[fid],
-                    index=total_len,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
                 )
             )
 
-        last_end = j
+        last_end = m.end()
 
-    # append remaining part
     parts.append(text[last_end:])
     cleaned_text = "".join(parts)
-    annotations.sort(key=lambda a: a.index)
     return annotations, cleaned_text
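The simplified extractor no longer skips whitespace after a marker or re-sorts annotations; it only trims the single space before the marker, so with the new "marker before punctuation" prompt each annotation's index lands exactly on the punctuation character in the cleaned text. A runnable sketch of the same loop, with a hypothetical FileCitation dataclass standing in for OpenAIResponseAnnotationFileCitation (file IDs and filenames are made up):

import re
from dataclasses import dataclass

@dataclass
class FileCitation:  # hypothetical stand-in for OpenAIResponseAnnotationFileCitation
    file_id: str
    filename: str
    index: int

FILE_ID_RE = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")

def extract_citations(text: str, citation_files: dict[str, str]) -> tuple[list[FileCitation], str]:
    annotations, parts = [], []
    total_len = 0  # length of cleaned text emitted so far
    last_end = 0
    for m in FILE_ID_RE.finditer(text):
        prefix = text[last_end : m.start()]
        if prefix.endswith(" "):  # drop the one space before the marker
            prefix = prefix[:-1]
        parts.append(prefix)
        total_len += len(prefix)
        fid = m.group("file_id")
        if fid in citation_files:
            # index points at the punctuation that follows the marker
            annotations.append(FileCitation(fid, citation_files[fid], total_len))
        last_end = m.end()
    parts.append(text[last_end:])
    return annotations, "".join(parts)

anns, clean = extract_citations(
    "This is a fact <|file-abc123|>. And a question <|file-def456|>?",
    {"file-abc123": "facts.md", "file-def456": "faq.md"},
)
print(clean)                                  # This is a fact. And a question?
print([(a.filename, a.index) for a in anns])  # [('facts.md', 14), ('faq.md', 30)]
assert clean[anns[0].index] == "." and clean[anns[1].index] == "?"

Dropping the old whitespace-skip also means a mid-sentence marker now keeps the space that follows it in the cleaned text, instead of fusing the surrounding words.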
@@ -333,6 +333,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             content=result.content or [],
             metadata={
                 **(result.metadata or {}),
-                "_annotation_file_mapping": getattr(result, "annotation_file_mapping", None),
+                "citation_files": getattr(result, "citation_files", None),
             },
         )
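The getattr(result, "citation_files", None) read keeps the producer tolerant of result objects that predate the attribute, and ToolExecutor picks the value back out of metadata with a plain .get (first ToolExecutor hunk above). A dict-level sketch of that round trip, using a mock result object with hypothetical values:

from types import SimpleNamespace

# Mock RAG query result; older result objects may lack citation_files entirely,
# which is why the producer uses getattr with a None default.
result = SimpleNamespace(
    metadata={"document_ids": ["file-abc123"]},
    citation_files={"file-abc123": "facts.md"},
)

# Producer side (MemoryToolRuntimeImpl): merge into the tool result metadata.
metadata = {
    **(result.metadata or {}),
    "citation_files": getattr(result, "citation_files", None),
}

# Consumer side (ToolExecutor): tolerate a missing key or a None value.
citation_files = metadata.get("citation_files") if metadata else None
print(citation_files)  # {'file-abc123': 'facts.md'}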
@@ -361,8 +361,7 @@ class TestExtractCitationsFromText:
 
         assert cleaned_text == expected_clean_text
         assert annotations == expected_annotations
-        # OpenAI typically cites at the end of the sentence but we support the middle just in case,
-        # which makes the position the start of the next word.
+        # OpenAI cites at the end of the sentence
         assert cleaned_text[expected_annotations[0].index] == "."
         assert cleaned_text[expected_annotations[1].index] == "?"
         assert cleaned_text[expected_annotations[2].index] == "!"