Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-09 05:08:37 +00:00)
Commit 75690a7cc6: Merge 9e61a4ab8c into sapling-pr-archive-ehhuang
20 changed files with 251 additions and 36 deletions
@@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Provider {model.provider_id} not found in the routing table")
         return self.impls_by_provider_id[model.provider_id]
 
+    async def has_model(self, model_id: str) -> bool:
+        """
+        Check if a model exists in the routing table.
+
+        :param model_id: The model identifier to check
+        :return: True if the model exists, False otherwise
+        """
+        try:
+            await lookup_model(self, model_id)
+            return True
+        except ModelNotFoundError:
+            return False
+
     async def register_model(
         self,
         model_id: str,
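Note: a minimal usage sketch of the new has_model helper (the table instance and model id below are illustrative, not part of this diff):

    # assumes `table` is a ModelsRoutingTable and "test_provider/test-model" was registered
    if await table.has_model("test_provider/test-model"):
        ...  # lookup_model() succeeded
    else:
        ...  # lookup_model() raised ModelNotFoundError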
@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
             # Text is the default response format for chat completion so don't need to pass it
             # (some providers don't support non-empty response_format when tools are present)
             response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+            logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
             completion_result = await self.inference_api.openai_chat_completion(
                 model=self.ctx.model,
                 messages=messages,
@@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
 
             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:
 
             for choice in current_response.choices:
                 next_turn_messages.append(choice.message)
+                logger.debug(f"Choice message content: {choice.message.content}")
+                logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
 
                 if choice.message.tool_calls and self.ctx.response_tools:
                     for tool_call in choice.message.tool_calls:
@@ -470,6 +475,8 @@ class StreamingResponseOrchestrator:
             tool_call_log = result.final_output_message
             tool_response_message = result.final_input_message
             self.sequence_number = result.sequence_number
+            if result.citation_files:
+                self.citation_files.update(result.citation_files)
 
             if tool_call_log:
                 output_messages.append(tool_call_log)
@@ -94,7 +94,10 @@ class ToolExecutor:
 
         # Yield the final result
         yield ToolExecutionResult(
-            sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+            sequence_number=sequence_number,
+            final_output_message=output_message,
+            final_input_message=input_message,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )
 
     async def _execute_knowledge_search_via_vector_store(
@@ -129,8 +132,6 @@ class ToolExecutor:
         for results in all_results:
             search_results.extend(results)
 
-        # Convert search results to tool result format matching memory.py
-        # Format the results as interleaved content similar to memory.py
         content_items = []
         content_items.append(
             TextContentItem(
@@ -138,27 +139,58 @@ class ToolExecutor:
             )
         )
 
+        unique_files = set()
         for i, result_item in enumerate(search_results):
             chunk_text = result_item.content[0].text if result_item.content else ""
-            metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+            # Get file_id from attributes if result_item.file_id is empty
+            file_id = result_item.file_id or (
+                result_item.attributes.get("document_id") if result_item.attributes else None
+            )
+            metadata_text = f"document_id: {file_id}, score: {result_item.score}"
             if result_item.attributes:
                 metadata_text += f", attributes: {result_item.attributes}"
-            text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+            text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
             content_items.append(TextContentItem(text=text_content))
+            unique_files.add(file_id)
 
         content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+
+        citation_instruction = ""
+        if unique_files:
+            citation_instruction = (
+                " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+                "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )
+
         content_items.append(
             TextContentItem(
-                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
             )
         )
 
+        # handling missing attributes for old versions
+        citation_files = {}
+        for result in search_results:
+            file_id = result.file_id
+            if not file_id and result.attributes:
+                file_id = result.attributes.get("document_id")
+
+            filename = result.filename
+            if not filename and result.attributes:
+                filename = result.attributes.get("filename")
+            if not filename:
+                filename = "unknown"
+
+            citation_files[file_id] = filename
+
         return ToolInvocationResult(
             content=content_items,
             metadata={
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
+                "citation_files": citation_files,
            },
        )
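Note: a sketch of what the reworked knowledge_search formatting aims to produce (the file id, filename, and chunk text are made-up examples, not from this diff):

    # one formatted chunk entry, with the new inline citation hint
    entry = "[1] document_id: file-abc123, score: 0.87 (cite as <|file-abc123|>)\nLlama Stack supports vector stores.\n"
    # the tool result metadata now also carries the id -> filename mapping
    citation_files = {"file-abc123": "getting_started.md"}
    # the model is then asked to cite like: "Llama Stack supports vector stores <|file-abc123|>."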
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
+    citation_files: dict[str, str] | None = None
 
 
 @dataclass
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import re
 import uuid
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
 )
 
 
-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
     if isinstance(choice.message.content, str):
@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )
 
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
         status="completed",
         role="assistant",
     )
@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)
 
 
+def _extract_citations_from_text(
+    text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+    """Extract citation markers from text and create annotations
+
+    Args:
+        text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
+        citation_files: Dictionary mapping file_id to filename
+
+    Returns:
+        Tuple of (annotations_list, clean_text_without_markers)
+    """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+    annotations = []
+    parts = []
+    total_len = 0
+    last_end = 0
+
+    for m in file_id_regex.finditer(text):
+        # segment before the marker
+        prefix = text[last_end : m.start()]
+
+        # drop one space if it exists (since marker is at sentence end)
+        if prefix.endswith(" "):
+            prefix = prefix[:-1]
+
+        parts.append(prefix)
+        total_len += len(prefix)
+
+        fid = m.group(1)
+        if fid in citation_files:
+            annotations.append(
+                OpenAIResponseAnnotationFileCitation(
+                    file_id=fid,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
+                )
+            )
+
+        last_end = m.end()
+
+    parts.append(text[last_end:])
+    cleaned_text = "".join(parts)
+    return annotations, cleaned_text
+
+
 def is_function_tool_call(
     tool_call: OpenAIChatCompletionToolCall,
     tools: list[OpenAIResponseInputTool],
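Note: a quick behavioral sketch of _extract_citations_from_text (illustrative strings, not from the tests further down):

    text = "The sky is blue <|file-abc123|>."
    annotations, clean = _extract_citations_from_text(text, {"file-abc123": "sky.pdf"})
    # clean == "The sky is blue."   (the marker and the space before it are dropped)
    # annotations[0].index == 15    (points at the "." that follows "blue")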
@@ -331,5 +331,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 
         return ToolInvocationResult(
             content=result.content or [],
-            metadata=result.metadata,
+            metadata={
+                **(result.metadata or {}),
+                "citation_files": getattr(result, "citation_files", None),
+            },
         )
@@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        # Cleanup if needed
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def health(self) -> HealthResponse:
         """
@@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        # nothing to do since we don't maintain a persistent connection
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def list_vector_dbs(self) -> list[VectorDB]:
         return [v.vector_db for v in self.cache.values()]
@@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         if self.conn is not None:
             self.conn.close()
             log.info("Connection to PGVector database server closed")
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         # Persist vector DB metadata in the KV store
@@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         await self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
             client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
@@ -474,11 +474,17 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
     async def check_model_availability(self, model: str) -> bool:
         """
-        Check if a specific model is available from the provider's /v1/models.
+        Check if a specific model is available from the provider's /v1/models or pre-registered.
 
         :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
+        :return: True if the model is available dynamically or pre-registered, False otherwise.
         """
+        # First check if the model is pre-registered in the model store
+        if hasattr(self, "model_store") and self.model_store:
+            if await self.model_store.has_model(model):
+                return True
+
+        # Then check the provider's dynamic model cache
         if not self._model_cache:
             await self.list_models()
         return model in self._model_cache
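Note: check_model_availability now resolves in two stages, the distribution's model store first and the provider's dynamic /v1/models cache second. A hedged sketch of the intended call path (the adapter variable and model id are placeholders):

    # assumes `provider` is an adapter built on OpenAIMixin with a populated model_store
    ok = await provider.check_model_availability("my-provider/custom-model")
    # True is returned without hitting /v1/models when the store already knows the id;
    # unknown ids fall through to list_models() and the _model_cache lookup.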
@@ -293,6 +293,19 @@ class OpenAIVectorStoreMixin(ABC):
             await self._resume_incomplete_batches()
         self._last_file_batch_cleanup_time = 0
 
+    async def shutdown(self) -> None:
+        """Clean up mixin resources including background tasks."""
+        # Cancel any running file batch tasks gracefully
+        if hasattr(self, "_file_batch_tasks"):
+            tasks_to_cancel = list(self._file_batch_tasks.items())
+            for _, task in tasks_to_cancel:
+                if not task.done():
+                    task.cancel()
+                    try:
+                        await task
+                    except asyncio.CancelledError:
+                        pass
+
     @abstractmethod
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
         """Delete chunks from a vector store."""
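Note: each vector-IO adapter above now ends its shutdown with await super().shutdown(), so provider-specific cleanup (closing clients/connections) runs first and this mixin then cancels any in-flight file batch tasks. A simplified, self-contained illustration of that chaining (class names are stand-ins, not the real adapters):

    import asyncio

    class _Mixin:
        async def shutdown(self) -> None:
            print("mixin cleanup")      # stands in for cancelling file batch tasks

    class _Adapter(_Mixin):
        async def shutdown(self) -> None:
            print("close client")       # provider-specific cleanup first
            await super().shutdown()    # then the shared mixin cleanup

    asyncio.run(_Adapter().shutdown())  # prints "close client" then "mixin cleanup"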
@@ -587,7 +600,7 @@ class OpenAIVectorStoreMixin(ABC):
             content = self._chunk_to_vector_store_content(chunk)
 
             response_data_item = VectorStoreSearchResponse(
-                file_id=chunk.metadata.get("file_id", ""),
+                file_id=chunk.metadata.get("document_id", ""),
                 filename=chunk.metadata.get("filename", ""),
                 score=score,
                 attributes=chunk.metadata,
@@ -746,12 +759,15 @@ class OpenAIVectorStoreMixin(ABC):
 
         content = content_from_data_and_mime_type(content_response.body, mime_type)
 
+        chunk_attributes = attributes.copy()
+        chunk_attributes["filename"] = file_response.filename
+
         chunks = make_overlapped_chunks(
             file_id,
             content,
             max_chunk_size_tokens,
             chunk_overlap_tokens,
-            attributes,
+            chunk_attributes,
         )
         if not chunks:
             vector_store_file_object.status = "failed"
@@ -16,10 +16,19 @@
 
 set -Eeuo pipefail
 
-CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker}
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+if command -v podman &> /dev/null; then
+    CONTAINER_RUNTIME="podman"
+elif command -v docker &> /dev/null; then
+    CONTAINER_RUNTIME="docker"
+else
+    echo "🚨 Neither Podman nor Docker could be found"
+    echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
+    exit 1
+fi
 
-echo "🚀 Setting up telemetry stack for Llama Stack using Podman..."
+echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..."
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then
     echo "🚨 $CONTAINER_RUNTIME could not be found"
@@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
     non_existent = await table.get_object_by_identifier("model", "non-existent-model")
     assert non_existent is None
 
+    # Test has_model
+    assert await table.has_model("test_provider/test-model")
+    assert await table.has_model("test_provider/test-model-2")
+    assert not await table.has_model("non-existent-model")
+    assert not await table.has_model("test_provider/non-existent-model")
+
     await table.unregister_model(model_id="test_provider/test-model")
     await table.unregister_model(model_id="test_provider/test-model-2")
@@ -8,6 +8,7 @@
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
     OpenAIUserMessageParam,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+    _extract_citations_from_text,
     convert_chat_choice_to_response_message,
     convert_response_content_to_chat_content,
     convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
 
         result = is_function_tool_call(tool_call, tools)
         assert result is False
+
+
+class TestExtractCitationsFromText:
+    def test_extract_citations_and_annotations(self):
+        text = "Start [not-a-file]. New source <|file-abc123|>. "
+        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
+        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
+
+        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
+
+        expected_annotations = [
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
+            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
+        ]
+        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
+
+        assert cleaned_text == expected_clean_text
+        assert annotations == expected_annotations
+        # OpenAI cites at the end of the sentence
+        assert cleaned_text[expected_annotations[0].index] == "."
+        assert cleaned_text[expected_annotations[1].index] == "?"
+        assert cleaned_text[expected_annotations[2].index] == "!"
@@ -44,11 +44,12 @@ def mixin():
     config = RemoteInferenceProviderConfig()
     mixin_instance = OpenAIMixinImpl(config=config)
 
-    # just enough to satisfy _get_provider_model_id calls
-    mock_model_store = MagicMock()
+    # Mock model_store with async methods
+    mock_model_store = AsyncMock()
     mock_model = MagicMock()
     mock_model.provider_resource_id = "test-provider-resource-id"
     mock_model_store.get_model = AsyncMock(return_value=mock_model)
+    mock_model_store.has_model = AsyncMock(return_value=False)  # Default to False, tests can override
     mixin_instance.model_store = mock_model_store
 
     return mixin_instance
@@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
 
         assert len(mixin._model_cache) == 3
 
+    async def test_check_model_availability_with_pre_registered_model(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability returns True for pre-registered models in model_store"""
+        # Mock model_store.has_model to return True for a specific model
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=True)
+        mixin.model_store = mock_model_store
+
+        # Test that pre-registered model is found without calling the provider's API
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("pre-registered-model")
+            # Should not call the provider's list_models since model was found in store
+            mock_client_with_models.models.list.assert_not_called()
+            mock_model_store.has_model.assert_called_once_with("pre-registered-model")
+
+    async def test_check_model_availability_fallback_to_provider_when_not_in_store(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability falls back to provider when model not in store"""
+        # Mock model_store.has_model to return False
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=False)
+        mixin.model_store = mock_model_store
+
+        # Test that it falls back to provider's model cache
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("some-mock-model-id")
+            # Should call the provider's list_models since model was not found in store
+            mock_client_with_models.models.list.assert_called_once()
+            mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
+
 
 class TestOpenAIMixinCacheBehavior:
     """Test cases for cache behavior and edge cases"""
@@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
 
 
 @pytest.fixture
-async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = SQLiteVectorIOConfig(
         db_path=sqlite_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = SQLiteVecVectorIOAdapter(
         config=config,
@@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
     config = MilvusVectorIOConfig(
         db_path=milvus_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = MilvusVectorIOAdapter(
         config=config,
@@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = ChromaVectorIOConfig(
         db_path=chroma_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = ChromaVectorIOAdapter(
         config=config,
@@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):
 
 
 @pytest.fixture
-async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
+async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import uuid
 
     config = QdrantVectorIOConfig(
         db_path=qdrant_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = QdrantVectorIOAdapter(
         config=config,
@@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
 
 
 @pytest.fixture
-async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = PGVectorVectorIOConfig(
         host="localhost",
         port=5432,
         db="test_db",
         user="test_user",
         password="test_password",
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
 
     adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):
 
 
 @pytest.fixture
-async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import pytest_socket
     import weaviate
 
@@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi
     config = WeaviateVectorIOConfig(
         weaviate_cluster_url="localhost:8080",
         weaviate_api_key=None,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = WeaviateVectorIOAdapter(
         config=config,