Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-08-18 16:11:36 +09:00 committed by GitHub
commit c66ebae9b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
207 changed files with 15490 additions and 7927 deletions

View file

@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
input,
model,
instructions,
previous_response_id,
store,
stream,
temperature,
text,
tools,
include,
max_infer_iters,
)
async def list_openai_responses(

View file

@ -1,880 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def _convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
temperature: float | None
response_format: OpenAIResponseFormatParam
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Structured outputs
response_format = await _convert_response_text_to_chat_response_format(text)
# Tool setup, TODO: refactor this slightly since this can also yield events
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
chat_tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
temperature=temperature,
response_format=response_format,
)
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
text=text,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
n_iter = 0
messages = ctx.messages.copy()
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.chat_tools,
stream=True,
temperature=ctx.temperature,
response_format=ctx.response_format,
)
# Process streaming chunks and build complete response
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
if tool_call.function.arguments:
# Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
current_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if _is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
text=text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if store:
await self._store_response(
response=final_response,
input=input,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
)
from llama_stack.apis.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
always_allowed = None
never_allowed = None
if input_tool.allowed_tools:
if isinstance(input_tool.allowed_tools, list):
always_allowed = input_tool.allowed_tools
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
always_allowed = input_tool.allowed_tools.always
never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=input_tool.headers or {},
)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
result = None
try:
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=tool_kwargs,
)
elif function.name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
{
"file_id": doc_id,
"filename": doc_id,
"text": text,
"score": score,
}
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc)
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
def _is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False

View file

@ -191,7 +191,11 @@ class AgentPersistence:
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
Inference,
OpenAISystemMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
vector_io_api=vector_io_api,
)
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
temperature=temperature,
response_format=response_format,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
inference_api=self.inference_api,
ctx=ctx,
response_id=response_id,
created_at=created_at,
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
)
# Stream the response
final_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
final_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=input,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)

View file

@ -0,0 +1,634 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionToolCall,
OpenAIChoice,
)
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
class StreamingResponseOrchestrator:
def __init__(
self,
inference_api: Inference,
ctx: ChatCompletionContext,
response_id: str,
created_at: int,
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
):
self.inference_api = inference_api
self.ctx = ctx
self.response_id = response_id
self.created_at = created_at
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
text=self.text,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=self.ctx.response_format,
)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
"""Process streaming chunks and emit events, returning completion data."""
# Initialize result tracking
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=self.sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
# Create new tool call entry if this is the first chunk for this index
is_new_tool_call = response_tool_call is None
if is_new_tool_call:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Create item ID for this tool call for streaming events
tool_call_item_id = f"fc_{uuid.uuid4()}"
tool_call_item_ids[tool_call.index] = tool_call_item_id
# Emit output_item.added event for the new function call
self.sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
tool_call_item_id = tool_call_item_ids[tool_call.index]
self.sequence_number += 1
# Check if this is an MCP tool call
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
if is_mcp_tool:
# Emit MCP-specific argument delta event
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
else:
# Emit function call argument delta event
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
self.sequence_number += 1
done_event_cls = (
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
if is_mcp_tool
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
)
yield done_event_cls(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Emit content_part.done event if text content was streamed (before content gets cleared)
if content_part_emitted:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
yield ChatCompletionResult(
response_id=chat_response_id,
content=chat_response_content,
tool_calls=chat_response_tool_calls,
created=chunk_created,
model=chunk_model,
finish_reason=chunk_finish_reason,
message_item_id=message_item_id,
tool_call_item_ids=tool_call_item_ids,
content_part_emitted=content_part_emitted,
)
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
"""Build OpenAIChatCompletion from ChatCompletionResult."""
# Convert collected chunks to complete response
if result.tool_calls:
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content=result.content_text,
tool_calls=tool_calls,
)
return OpenAIChatCompletion(
id=result.response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=result.finish_reason,
index=0,
)
],
created=result.created,
model=result.model,
)
async def _coordinate_tool_execution(
self,
function_tool_calls: list,
non_function_tool_calls: list,
completion_result_data: ChatCompletionResult,
output_messages: list[OpenAIResponseOutput],
next_turn_messages: list,
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use a fallback item_id if not found
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
async for result in self.tool_executor.execute_tool_call(
tool_call,
self.ctx,
self.sequence_number,
len(output_messages),
matching_item_id,
self.mcp_tool_to_server,
):
if result.stream_event:
# Forward streaming events
self.sequence_number = result.sequence_number
yield result.stream_event
if result.final_output_message is not None:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
if matching_item_id:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=tool_call_log,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use existing item_id or create new one if not found
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=final_item_id,
status="completed",
)
output_messages.append(function_call_item)
# Emit output_item.done event for completed function call
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _process_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
yield stream_event
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
async def _process_mcp_tool(
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process an MCP tool configuration and emit appropriate streaming events."""
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
# Emit mcp_list_tools.in_progress
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
never_allowed = None
if mcp_tool.allowed_tools:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=mcp_tool.server_label,
tools=[],
)
# Process tools and update context
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
self.mcp_tool_to_server[t.name] = mcp_tool
# Add to MCP list message
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise

View file

@ -0,0 +1,379 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIToolMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
class ToolExecutor:
def __init__(
self,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
vector_io_api: VectorIO,
):
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.vector_io_api = vector_io_api
async def execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit progress events for tool execution start
async for event_result in self._emit_progress_events(
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Execute the actual tool call
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Build result messages from tool execution
output_message, input_message = await self._build_result_messages(
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
)
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
)
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _emit_progress_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
try:
if mcp_tool_to_server and function_name in mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
return error_exc, result
async def _emit_completion_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
function,
tool_call_id: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
attributes={},
)
)
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputTool,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
@dataclass
class ChatCompletionResult:
"""Result of processing streaming chat completion chunks."""
response_id: str
content: list[str]
tool_calls: dict[int, OpenAIChatCompletionToolCall]
created: int
model: str
finish_reason: str
message_item_id: str # For streaming events
tool_call_item_ids: dict[int, str] # For streaming events
content_part_emitted: bool # Tracking state
@property
def content_text(self) -> str:
"""Get joined content as string."""
return "".join(self.content)
@property
def has_tool_calls(self) -> bool:
"""Check if there are any tool calls."""
return bool(self.tool_calls)
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam

View file

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from .batches import ReferenceBatchesImpl
from .config import ReferenceBatchesImplConfig
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
kvstore = await kvstore_impl(config.kvstore)
inference_api: Inference | None = deps.get(Api.inference)
files_api: Files | None = deps.get(Api.files)
models_api: Models | None = deps.get(Api.models)
if inference_api is None:
raise ValueError("Inference API is required but not provided in dependencies")
if files_api is None:
raise ValueError("Files API is required but not provided in dependencies")
if models_api is None:
raise ValueError("Models API is required but not provided in dependencies")
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
await impl.initialize()
return impl

View file

@ -0,0 +1,580 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import itertools
import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from .config import ReferenceBatchesImplConfig
BATCH_PREFIX = "batch:"
logger = get_logger(__name__)
class AsyncBytesIO:
"""
Async-compatible BytesIO wrapper to allow async file-like operations.
We use this when uploading files to the Files API, as it expects an
async file-like object.
"""
def __init__(self, data: bytes):
self._buffer = BytesIO(data)
async def read(self, n=-1):
return self._buffer.read(n)
async def seek(self, pos, whence=0):
return self._buffer.seek(pos, whence)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._buffer.close()
def __getattr__(self, name):
return getattr(self._buffer, name)
class BatchRequest(BaseModel):
line_num: int
custom_id: str
method: str
url: str
body: dict[str, Any]
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
"""Convert a message dictionary to OpenAIMessageParam based on role."""
role = msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**msg)
elif role == "system":
return OpenAISystemMessageParam(**msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**msg)
elif role == "tool":
return OpenAIToolMessageParam(**msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**msg)
else:
raise ValueError(f"Unknown message role: {role}")
class ReferenceBatchesImpl(Batches):
"""Reference implementation of the Batches API.
This implementation processes batch files by making individual requests
to the inference API and generates output files with results.
"""
def __init__(
self,
config: ReferenceBatchesImplConfig,
inference_api: Inference,
files_api: Files,
models_api: Models,
kvstore: KVStore,
) -> None:
self.config = config
self.kvstore = kvstore
self.inference_api = inference_api
self.files_api = files_api
self.models_api = models_api
self._processing_tasks: dict[str, asyncio.Task] = {}
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
self._update_batch_lock = asyncio.Lock()
# this is to allow tests to disable background processing
self.process_batches = True
async def initialize(self) -> None:
# TODO: start background processing of existing tasks
pass
async def shutdown(self) -> None:
"""Shutdown the batches provider."""
if self._processing_tasks:
# don't cancel tasks - just let them stop naturally on shutdown
# cancelling would mark batches as "cancelled" in the database
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
"""
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
current_time = int(time.time())
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))
self._processing_tasks[batch_id] = task
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
With no notion of user, we return all batches.
"""
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
batches = []
for batch_data in batch_values:
if batch_data:
batches.append(BatchObject.model_validate_json(batch_data))
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
for i, batch in enumerate(batches):
if batch.id == after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None
return ListBatchesResponse(
data=page_batches,
first_id=first_id,
last_id=last_id,
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
async def _update_batch(self, batch_id: str, **updates) -> None:
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
logger.info(
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
)
return
if "errors" in updates:
updates["errors"] = updates["errors"].model_dump()
batch_dict = batch.model_dump()
batch_dict.update(updates)
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
except Exception as e:
logger.error(f"Failed to update batch {batch_id}: {e}")
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
"""
Read & validate input, return errors and valid input.
Validation of
- input_file_id existance
- valid json
- custom_id, method, url, body presence and valid
- no streaming
"""
requests: list[BatchRequest] = []
errors: list[BatchError] = []
try:
await self.files_api.openai_retrieve_file(batch.input_file_id)
except Exception:
errors.append(
BatchError(
code="invalid_request",
line=None,
message=f"Cannot find file {batch.input_file_id}.",
param="input_file_id",
)
)
return errors, requests
# TODO(SECURITY): do something about large files
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
file_content = file_content_response.body.decode("utf-8")
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
if line.strip(): # skip empty lines
try:
request = json.loads(line)
if not isinstance(request, dict):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message="Each line must be a JSON dictionary object",
)
)
continue
valid = True
for param, expected_type, type_string in [
("custom_id", str, "string"),
("method", str, "string"),
("url", str, "string"),
("body", dict, "JSON dictionary object"),
]:
if param not in request:
errors.append(
BatchError(
code="missing_required_parameter",
line=line_num,
message=f"Missing required parameter: {param}",
param=param,
)
)
valid = False
elif not isinstance(request[param], expected_type):
param_name = "URL" if param == "url" else param.capitalize()
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param_name} must be a {type_string}",
param=param,
)
)
valid = False
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
errors.append(
BatchError(
code="invalid_url",
line=line_num,
message="URL provided for this request does not match the batch endpoint",
param="url",
)
)
valid = False
if (body := request.get("body")) and isinstance(body, dict):
if body.get("stream", False):
errors.append(
BatchError(
code="streaming_unsupported",
line=line_num,
message="Streaming is not supported in batch processing",
param="body.stream",
)
)
valid = False
for param, expected_type, type_string in [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]:
if param not in body:
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} parameter is required",
param=f"body.{param}",
)
)
valid = False
elif not isinstance(body[param], expected_type):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} must be {type_string}",
param=f"body.{param}",
)
)
valid = False
if "model" in body and isinstance(body["model"], str):
try:
await self.models_api.get_model(body["model"])
except Exception:
errors.append(
BatchError(
code="model_not_found",
line=line_num,
message=f"Model '{body['model']}' does not exist or is not supported",
param="body.model",
)
)
valid = False
if valid:
assert isinstance(url, str), "URL must be a string" # for mypy
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
requests.append(
BatchRequest(
line_num=line_num,
url=url,
method=request["method"],
custom_id=request["custom_id"],
body=body,
),
)
except json.JSONDecodeError:
errors.append(
BatchError(
code="invalid_json_line",
line=line_num,
message="This line is not parseable as valid JSON.",
)
)
return errors, requests
async def _process_batch(self, batch_id: str) -> None:
"""Background task to process a batch of requests."""
try:
logger.info(f"Starting batch processing for {batch_id}")
async with self._batch_semaphore: # semaphore to limit concurrency
logger.info(f"Acquired semaphore for batch {batch_id}")
await self._process_batch_impl(batch_id)
except asyncio.CancelledError:
logger.info(f"Batch processing cancelled for {batch_id}")
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
except Exception as e:
logger.error(f"Batch processing failed for {batch_id}: {e}")
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
)
finally:
self._processing_tasks.pop(batch_id, None)
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
errors, requests = await self._validate_input(batch)
if errors:
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
return
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
total_requests = len(requests)
await self._update_batch(
batch_id,
status="in_progress",
request_counts={"total": total_requests, "completed": 0, "failed": 0},
)
error_results = []
success_results = []
completed_count = 0
failed_count = 0
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
async with asyncio.TaskGroup() as tg:
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
for result in chunk_results:
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
failed_count += 1
error_results.append(result)
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
completed_count += 1
success_results.append(result)
else: # unexpected result
failed_count += 1
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
await self._update_batch(
batch_id,
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
)
if errors:
await self._update_batch(
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
)
return
try:
output_file_id = await self._create_output_file(batch_id, success_results, "success")
await self._update_batch(batch_id, output_file_id=output_file_id)
error_file_id = await self._create_output_file(batch_id, error_results, "error")
await self._update_batch(batch_id, error_file_id=error_file_id)
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
logger.info(
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
)
except Exception as e:
# note: errors is empty at this point, so we don't lose anything by ignoring it
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
)
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
"""Process a single request from the batch."""
request_id = f"batch_req_{batch_id}_{request.line_num}"
try:
# TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {
"id": request_id,
"custom_id": request.custom_id,
"error": {"type": "request_failed", "message": str(e)},
}
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
"""
Create an output file with batch results.
This function filters results based on the specified file_type
and uploads the file to the Files API.
"""
output_lines = [json.dumps(result) for result in results]
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
return uploaded_file.id

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
class ReferenceBatchesImplConfig(BaseModel):
"""Configuration for the Reference Batches implementation."""
kvstore: KVStoreConfig = Field(
description="Configuration for the key-value store backend.",
)
max_concurrent_batches: int = Field(
default=1,
description="Maximum number of concurrent batches to process simultaneously.",
ge=1,
)
max_concurrent_requests_per_batch: int = Field(
default=10,
description="Maximum number of concurrent requests to process per batch.",
ge=1,
)
# TODO: add a max requests per second rate limiter
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="batches.db",
),
}

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
@ -424,9 +400,9 @@ class LlamaGuardShield:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
flagged = False
user_message = None
metadata = {}
@ -453,19 +429,15 @@ class LlamaGuardShield:
],
)
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
category_scores = {
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
flagged = True
user_message = CANNED_RESPONSE_TEXT

View file

@ -18,6 +18,7 @@ from llama_stack.apis.safety import (
ShieldStore,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -64,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
class PromptGuardShield:
def __init__(

View file

@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex):
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
if not set(chunk_ids).issubset(self.chunk_ids):
return
async with self.chunk_id_lock:
def remove_chunk(chunk_id: str):
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
async with self.chunk_id_lock:
for chunk_id in chunk_ids:
remove_chunk(chunk_id)
await self._save_index()
async def query_vector(
@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)
await faiss_index.delete_chunks(chunks_for_deletion)

View file

@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the SQLite vector store."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
def _delete_chunk():
def _delete_chunks():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
placeholders = ",".join("?" * len(chunk_ids))
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids)
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
logger.error(f"Error deleting chunks: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
await asyncio.to_thread(_delete_chunks)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.batches,
provider_type="inline::reference",
pip_packages=["openai"],
module="llama_stack.providers.inline.batches.reference",
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
api_dependencies=[
Api.inference,
Api.files,
Api.models,
],
description="Reference implementation of batches API with KVStore persistence.",
),
]

View file

@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -464,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -731,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,

View file

@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -457,9 +457,6 @@ class OllamaInferenceAdapter(
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model {model} is not an embedding model")
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")

View file

@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,
)
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -146,8 +147,10 @@ class ChromaIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
await maybe_await(self.collection.delete([chunk_id]))
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a single chunk from the Chroma collection by its ID."""
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
await maybe_await(self.collection.delete(ids=ids))
async def query_hybrid(
self,
@ -175,6 +178,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.files_api = files_api
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -258,5 +262,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Milvus collection."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Qdrant collection."""
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
points_selector=models.PointIdsList(points=chunk_ids),
)
except Exception as e:
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> VectorStoreFileObject:
# Qdrant doesn't allow multiple clients to access the same storage path simultaneously.
async with self._qdrant_lock:
await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy)
return await super().openai_attach_file_to_vector_store(
vector_store_id, file_id, attributes, chunking_strategy
)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_id": chunk.chunk_id,
"chunk_content": chunk.model_dump_json(),
},
vector=embeddings[i].tolist(),
@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id]))
chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion]
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter(
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index:
raise ValueError(f"Vector DB {sanitized_collection_name} not found")
await index.delete(chunk_ids)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -31,15 +31,15 @@ from openai.types.chat import (
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.chat.chat_completion_message_tool_call import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new(
)
elif isinstance(message, CompletionMessage):
tool_calls = [
OpenAIChatCompletionMessageToolCall(
OpenAIChatCompletionMessageFunctionToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
@ -903,7 +903,7 @@ def _convert_openai_request_response_format(
def _convert_openai_tool_calls(
tool_calls: list[OpenAIChatCompletionMessageToolCall],
tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
) -> list[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

View file

@ -6,7 +6,6 @@
import asyncio
import json
import logging
import mimetypes
import time
import uuid
@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
content_from_data_and_mime_type,
make_overlapped_chunks,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__, category="vector_io")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC):
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a vector store."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store."""
pass
@abstractmethod
@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC):
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.error(f"Error attaching file to vector store: {e}")
logger.exception("Error attaching file to vector store")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC):
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
# Create ChunkForDeletion objects with both chunk_id and document_id
chunks_for_deletion = []
for c in chunks:
if c.chunk_id:
document_id = c.metadata.get("document_id") or (
c.chunk_metadata.document_id if c.chunk_metadata else None
)
if document_id:
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
else:
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
if chunks_for_deletion:
await self.delete_chunks(vector_store_id, chunks_for_deletion)
store_info = self.openai_vector_stores[vector_store_id].copy()

View file

@ -16,6 +16,7 @@ from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
:param chunk_id: The ID of the chunk to delete
:param document_id: The ID of the document this chunk belongs to
"""
chunk_id: str
document_id: str
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
@ -232,7 +245,7 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
raise NotImplementedError()
@abstractmethod