Merge branch 'main' into content-extension

This commit is contained in:
Francisco Arceo 2025-08-13 14:04:47 -06:00 committed by GitHub
commit 84a26339c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
73 changed files with 2416 additions and 506 deletions

View file

@ -706,6 +706,7 @@ class Agents(Protocol):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
@ -713,6 +714,7 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param include: (Optional) Additional fields to include in the response.
:returns: An OpenAIResponseObject.
"""
...

View file

@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
type: Literal["web_search_call"] = "web_search_call"
class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel):
"""Search results returned by the file search operation.
:param attributes: (Optional) Key-value attributes associated with the file
:param file_id: Unique identifier of the file containing the result
:param filename: Name of the file containing the result
:param score: Relevance score for this search result (between 0 and 1)
:param text: Text content of the search result
"""
attributes: dict[str, Any]
file_id: str
filename: str
score: float
text: str
@json_schema_type
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""File search tool call output message for OpenAI responses.
@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
queries: list[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[dict[str, Any]] | None = None
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type

View file

@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError):
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""
def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
message = (
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
)
super().__init__(message)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from enum import Enum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel):
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
Required set of categories that need to be in response
- violence
- violence/graphic
- harassment
- harassment/threatening
- hate
- hate/threatening
- illicit
- illicit/violent
- sexual
- sexual/minors
- self-harm
- self-harm/intent
- self-harm/instructions
"""
flagged: bool

View file

@ -91,7 +91,7 @@ def get_provider_dependencies(
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
normal_deps, special_deps, _ = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",

View file

@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
status_code = httpx.codes.OK
if options.method.upper() == "DELETE" and result is None:
status_code = httpx.codes.NO_CONTENT
if status_code == httpx.codes.NO_CONTENT:
json_content = ""
mock_response = httpx.Response(
status_code=httpx.codes.OK,
status_code=status_code,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",

View file

@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="inference")
class InferenceRouter(Inference):
@ -177,6 +177,15 @@ class InferenceRouter(Inference):
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type != expected_model_type:
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def chat_completion(
self,
model_id: str,
@ -195,11 +204,7 @@ class InferenceRouter(Inference):
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
model = await self._get_model(model_id, ModelType.llm)
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
@ -301,11 +306,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
model = await self._get_model(model_id, ModelType.llm)
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -355,11 +356,7 @@ class InferenceRouter(Inference):
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
await self._get_model(model_id, ModelType.embedding)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
@ -395,12 +392,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
@ -476,11 +468,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
model_obj = await self._get_model(model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
@ -567,12 +555,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model '{model}' is not an embedding model")
model_obj = await self._get_model(model, ModelType.embedding)
params = dict(
model=model_obj.identifier,
input=input,
@ -871,4 +854,5 @@ class InferenceRouter(Inference):
model=model.identifier,
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
await self.store.store_chat_completion(final_response, messages)

View file

@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -82,20 +82,5 @@ class SafetyRouter(Safety):
input=input,
model=model,
)
self._validate_required_categories_exist(response)
return response
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
required_categories = list(map(str, OpenAICategories))
categories = response.results[0].categories
category_applied_input_types = response.results[0].category_applied_input_types
category_scores = response.results[0].category_scores
for i in [categories, category_applied_input_types, category_scores]:
if not set(required_categories).issubset(set(i.keys())):
raise ValueError(
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
)

View file

@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
if model.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
async def register_model(

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {

View file

@ -21,10 +21,11 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
import httpx
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
@ -115,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
if isinstance(exc, RequestValidationError):
return HTTPException(
status_code=400,
status_code=httpx.codes.BAD_REQUEST,
detail={
"errors": [
{
@ -128,20 +129,20 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
},
)
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=400, detail=str(exc))
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
detail="Internal server error: An unexpected error occurred.",
)
@ -180,7 +181,6 @@ async def sse_generator(event_gen_coroutine):
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
if event_gen:
@ -236,6 +236,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result = await maybe_await(value)
if isinstance(result, PaginatedResponse) and result.url is None:
result.url = route
if method.upper() == "DELETE" and result is None:
return Response(status_code=httpx.codes.NO_CONTENT)
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
@ -352,7 +356,7 @@ class ClientVersionMiddleware:
await send(
{
"type": "http.response.start",
"status": 426,
"status": httpx.codes.UPGRADE_REQUIRED,
"headers": [[b"content-type", b"application/json"]],
}
)

View file

@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
def get_distribution_template() -> DistributionTemplate:
@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
chromadb_provider = Provider(
provider_id="chromadb",
provider_type="remote::chromadb",
config={
"url": "${env.CHROMA_URL}",
},
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
)
inference_model = ModelInput(

View file

@ -26,7 +26,10 @@ providers:
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMA_URL}
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -22,7 +22,10 @@ providers:
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMA_URL}
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -32,6 +32,7 @@ CATEGORIES = [
"tools",
"client",
"telemetry",
"openai_responses",
]
# Initialize category levels with default level

View file

@ -236,6 +236,7 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
input,
model,
instructions,
previous_response_id,
store,
stream,
temperature,
text,
tools,
include,
max_infer_iters,
)
async def list_openai_responses(

View file

@ -33,11 +33,16 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
@ -72,7 +77,9 @@ from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.inference.openai_compat import (
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
logger = get_logger(name=__name__, category="openai_responses")
@ -81,7 +88,7 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@ -149,7 +156,9 @@ async def _convert_response_input_to_chat_messages(
return messages
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def _convert_chat_choice_to_response_message(
choice: OpenAIChoice,
) -> OpenAIResponseMessage:
"""
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
@ -171,7 +180,9 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
)
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
async def _convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
@ -227,7 +238,9 @@ class OpenAIResponsesImpl:
self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
@ -333,6 +346,7 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
@ -444,6 +458,8 @@ class OpenAIResponsesImpl:
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
async for chunk in completion_result:
chat_response_id = chunk.id
@ -470,24 +486,72 @@ class OpenAIResponsesImpl:
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
if tool_call.function.arguments:
# Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
else:
# Create new tool call entry if this is the first chunk for this index
is_new_tool_call = response_tool_call is None
if is_new_tool_call:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
chat_response_tool_calls[tool_call.index] = response_tool_call
# Create item ID for this tool call for streaming events
tool_call_item_id = f"fc_{uuid.uuid4()}"
tool_call_item_ids[tool_call.index] = tool_call_item_id
# Emit output_item.added event for the new function call
sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=sequence_number,
)
# Stream function call arguments as they arrive
if tool_call.function and tool_call.function.arguments:
tool_call_item_id = tool_call_item_ids[tool_call.index]
sequence_number += 1
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit function_call_arguments.done events for completed tool calls
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
sequence_number += 1
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=sequence_number,
)
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
# when there are tool calls, we need to clear the content
chat_response_content = []
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
@ -526,18 +590,56 @@ class OpenAIResponsesImpl:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
if matching_item_id:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=response_id,
item=tool_call_log,
output_index=len(output_messages) - 1,
sequence_number=sequence_number,
)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
# Find the item_id for this tool call from our tracking dictionary
matching_item_id = None
for index, item_id in tool_call_item_ids.items():
response_tool_call = chat_response_tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use existing item_id or create new one if not found
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=final_item_id,
status="completed",
)
output_messages.append(function_call_item)
# Emit output_item.done event for completed function call
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=response_id,
item=function_call_item,
output_index=len(output_messages) - 1,
sequence_number=sequence_number,
)
if not function_tool_calls and not non_function_tool_calls:
@ -773,7 +875,8 @@ class OpenAIResponsesImpl:
)
elif function.name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
@ -792,7 +895,9 @@ class OpenAIResponsesImpl:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
@ -826,12 +931,13 @@ class OpenAIResponsesImpl:
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
{
"file_id": doc_id,
"filename": doc_id,
"text": text,
"score": score,
}
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
attributes={},
)
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
@ -843,7 +949,10 @@ class OpenAIResponsesImpl:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
content = []
for item in result.content:

View file

@ -191,7 +191,11 @@ class AgentPersistence:
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
@ -424,9 +400,9 @@ class LlamaGuardShield:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
flagged = False
user_message = None
metadata = {}
@ -453,19 +429,15 @@ class LlamaGuardShield:
],
)
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
category_scores = {
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
flagged = True
user_message = CANNED_RESPONSE_TEXT

View file

@ -18,6 +18,7 @@ from llama_stack.apis.safety import (
ShieldStore,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -64,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
class PromptGuardShield:
def __init__(

View file

@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex):
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
if not set(chunk_ids).issubset(self.chunk_ids):
return
async with self.chunk_id_lock:
def remove_chunk(chunk_id: str):
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
async with self.chunk_id_lock:
for chunk_id in chunk_ids:
remove_chunk(chunk_id)
await self._save_index()
async def query_vector(
@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)
await faiss_index.delete_chunks(chunks_for_deletion)

View file

@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the SQLite vector store."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
def _delete_chunk():
def _delete_chunks():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
placeholders = ",".join("?" * len(chunk_ids))
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids)
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
logger.error(f"Error deleting chunks: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
await asyncio.to_thread(_delete_chunks)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -464,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -731,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,

View file

@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -457,9 +457,6 @@ class OllamaInferenceAdapter(
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model {model} is not an embedding model")
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")

View file

@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,
)
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -115,8 +116,10 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a single chunk from the Chroma collection by its ID."""
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
await maybe_await(self.collection.delete(ids=ids))
async def query_hybrid(
self,
@ -144,6 +147,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.files_api = files_api
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -227,5 +231,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Milvus collection."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Qdrant collection."""
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
points_selector=models.PointIdsList(points=chunk_ids),
)
except Exception as e:
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> VectorStoreFileObject:
# Qdrant doesn't allow multiple clients to access the same storage path simultaneously.
async with self._qdrant_lock:
await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy)
return await super().openai_attach_file_to_vector_store(
vector_store_id, file_id, attributes, chunking_strategy
)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_id": chunk.chunk_id,
"chunk_content": chunk.model_dump_json(),
},
vector=embeddings[i].tolist(),
@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id]))
chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion]
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter(
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index:
raise ValueError(f"Vector DB {sanitized_collection_name} not found")
await index.delete(chunk_ids)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.chat.chat_completion_message_tool_call import (
Function as OpenAIFunction,
)
from pydantic import BaseModel

View file

@ -6,7 +6,6 @@
import asyncio
import json
import logging
import mimetypes
import time
import uuid
@ -38,10 +37,15 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
content_from_data_and_mime_type,
make_overlapped_chunks,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__, category="vector_io")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
@ -155,8 +159,8 @@ class OpenAIVectorStoreMixin(ABC):
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a vector store."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store."""
pass
@abstractmethod
@ -652,7 +656,7 @@ class OpenAIVectorStoreMixin(ABC):
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.error(f"Error attaching file to vector store: {e}")
logger.exception("Error attaching file to vector store")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
@ -805,7 +809,21 @@ class OpenAIVectorStoreMixin(ABC):
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
# Create ChunkForDeletion objects with both chunk_id and document_id
chunks_for_deletion = []
for c in chunks:
if c.chunk_id:
document_id = c.metadata.get("document_id") or (
c.chunk_metadata.document_id if c.chunk_metadata else None
)
if document_id:
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
else:
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
if chunks_for_deletion:
await self.delete_chunks(vector_store_id, chunks_for_deletion)
store_info = self.openai_vector_stores[vector_store_id].copy()

View file

@ -16,6 +16,7 @@ from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
:param chunk_id: The ID of the chunk to delete
:param document_id: The ID of the document this chunk belongs to
"""
chunk_id: str
document_id: str
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
@ -232,7 +245,7 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
raise NotImplementedError()
@abstractmethod