mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-05 02:17:31 +00:00
Merge branch 'main' into chroma
This commit is contained in:
commit
470adfc2df
750 changed files with 243399 additions and 28283 deletions
|
|
@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
|
|||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.tools import ToolGroup
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -68,10 +68,10 @@ class ShieldsProtocolPrivate(Protocol):
|
|||
async def unregister_shield(self, identifier: str) -> None: ...
|
||||
|
||||
|
||||
class VectorDBsProtocolPrivate(Protocol):
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
||||
class VectorStoresProtocolPrivate(Protocol):
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None: ...
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None: ...
|
||||
|
||||
|
||||
class DatasetsProtocolPrivate(Protocol):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,12 @@ from llama_stack.core.datatypes import AccessRule, Api
|
|||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceAgentsImplConfig,
|
||||
deps: dict[Api, Any],
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
@ -23,7 +28,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
|
|||
deps[Api.tool_groups],
|
||||
deps[Api.conversations],
|
||||
policy,
|
||||
Api.telemetry in deps,
|
||||
telemetry_enabled,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.agents import (
|
|||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
|
|
@ -82,8 +83,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.policy = policy
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence.agent_state)
|
||||
self.responses_store = ResponsesStore(self.config.persistence.responses, self.policy)
|
||||
await self.responses_store.initialize()
|
||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||
inference_api=self.inference_api,
|
||||
|
|
@ -91,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api=self.tool_runtime_api,
|
||||
responses_store=self.responses_store,
|
||||
vector_io_api=self.vector_io_api,
|
||||
safety_api=self.safety_api,
|
||||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
|
|
@ -337,7 +339,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
|
|
@ -352,7 +354,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
shields,
|
||||
guardrails,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
|||
|
|
@ -8,24 +8,30 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
|
||||
|
||||
class AgentPersistenceConfig(BaseModel):
|
||||
"""Nested persistence configuration for agents."""
|
||||
|
||||
agent_state: KVStoreReference
|
||||
responses: ResponsesStoreReference
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||
persistence_store: KVStoreConfig
|
||||
responses_store: SqlStoreConfig
|
||||
persistence: AgentPersistenceConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="agents_store.db",
|
||||
),
|
||||
"responses_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="responses_store.db",
|
||||
),
|
||||
"persistence": {
|
||||
"agent_state": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
).model_dump(exclude_none=True),
|
||||
"responses": ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
|
|||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
|
|
@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -48,6 +50,7 @@ from .types import ChatCompletionContext, ToolContext
|
|||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
extract_guardrail_ids,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
|
@ -66,6 +69,7 @@ class OpenAIResponsesImpl:
|
|||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
safety_api: Safety,
|
||||
conversations_api: Conversations,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
|
|
@ -73,6 +77,7 @@ class OpenAIResponsesImpl:
|
|||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.safety_api = safety_api
|
||||
self.conversations_api = conversations_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
|
|
@ -100,6 +105,7 @@ class OpenAIResponsesImpl:
|
|||
input: str | list[OpenAIResponseInput],
|
||||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
conversation: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
|
|
@ -124,16 +130,39 @@ class OpenAIResponsesImpl:
|
|||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
tool_context.recover_tools_from_previous_response(previous_response)
|
||||
elif conversation is not None:
|
||||
conversation_items = await self.conversations_api.list(conversation, order="asc")
|
||||
|
||||
# Use stored messages as source of truth (like previous_response.messages)
|
||||
stored_messages = await self.responses_store.get_conversation_messages(conversation)
|
||||
|
||||
all_input = input
|
||||
if not conversation_items.data:
|
||||
# First turn - just convert the new input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
else:
|
||||
if not stored_messages:
|
||||
all_input = conversation_items.data
|
||||
if isinstance(input, str):
|
||||
all_input.append(
|
||||
OpenAIResponseMessage(
|
||||
role="user", content=[OpenAIResponseInputMessageContentText(text=input)]
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_input.extend(input)
|
||||
else:
|
||||
all_input = input
|
||||
|
||||
messages = stored_messages or []
|
||||
new_messages = await convert_response_input_to_chat_messages(all_input, previous_messages=messages)
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
all_input = input
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
messages = await convert_response_input_to_chat_messages(all_input)
|
||||
|
||||
return all_input, messages, tool_context
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
|
|
@ -220,41 +249,34 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Shields parameter received via extra_body - not yet implemented
|
||||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
|
||||
|
||||
if conversation is not None and previous_response_id is not None:
|
||||
raise ValueError(
|
||||
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
|
||||
)
|
||||
|
||||
original_input = input # needed for syncing to Conversations
|
||||
if conversation is not None:
|
||||
if previous_response_id is not None:
|
||||
raise ValueError(
|
||||
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
|
||||
)
|
||||
|
||||
if not conversation.startswith("conv_"):
|
||||
raise InvalidConversationIdError(conversation)
|
||||
|
||||
# Check conversation exists (raises ConversationNotFoundError if not)
|
||||
_ = await self.conversations_api.get_conversation(conversation)
|
||||
input = await self._load_conversation_context(conversation, input)
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
original_input=original_input,
|
||||
conversation=conversation,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
conversation=conversation,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
guardrail_ids=guardrail_ids,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
@ -292,7 +314,6 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
original_input: str | list[OpenAIResponseInput] | None = None,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
conversation: str | None = None,
|
||||
|
|
@ -301,12 +322,15 @@ class OpenAIResponsesImpl:
|
|||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id
|
||||
input, tools, previous_response_id, conversation
|
||||
)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
# Structured outputs
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
|
@ -333,11 +357,16 @@ class OpenAIResponsesImpl:
|
|||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
safety_api=self.safety_api,
|
||||
guardrail_ids=guardrail_ids,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
failed_response = None
|
||||
|
||||
output_items = []
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
final_response = stream_chunk.response
|
||||
|
|
@ -345,102 +374,50 @@ class OpenAIResponsesImpl:
|
|||
failed_response = stream_chunk.response
|
||||
yield stream_chunk
|
||||
|
||||
if stream_chunk.type == "response.output_item.done":
|
||||
item = stream_chunk.item
|
||||
output_items.append(item)
|
||||
|
||||
# Store and sync immediately after yielding terminal events
|
||||
# This ensures the storage/syncing happens even if the consumer breaks early
|
||||
if (
|
||||
stream_chunk.type in {"response.completed", "response.incomplete"}
|
||||
and store
|
||||
and final_response
|
||||
and failed_response is None
|
||||
):
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=all_input,
|
||||
messages=orchestrator.final_messages,
|
||||
messages_to_store = list(
|
||||
filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
|
||||
)
|
||||
if store:
|
||||
# TODO: we really should work off of output_items instead of "final_messages"
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=all_input,
|
||||
messages=messages_to_store,
|
||||
)
|
||||
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"} and conversation and final_response:
|
||||
# for Conversations, we need to use the original_input if it's available, otherwise use input
|
||||
sync_input = original_input if original_input is not None else input
|
||||
await self._sync_response_to_conversation(conversation, sync_input, final_response)
|
||||
if conversation:
|
||||
await self._sync_response_to_conversation(conversation, input, output_items)
|
||||
await self.responses_store.store_conversation_messages(conversation, messages_to_store)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _load_conversation_context(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput]
|
||||
) -> list[OpenAIResponseInput]:
|
||||
"""Load conversation history and merge with provided content."""
|
||||
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
|
||||
|
||||
context_messages = []
|
||||
for item in conversation_items.data:
|
||||
if isinstance(item, OpenAIResponseMessage):
|
||||
if item.role == "user":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
elif item.role == "assistant":
|
||||
context_messages.append(
|
||||
OpenAIResponseMessage(
|
||||
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
|
||||
)
|
||||
)
|
||||
|
||||
# add new content to context
|
||||
if isinstance(content, str):
|
||||
context_messages.append(OpenAIResponseMessage(role="user", content=content))
|
||||
elif isinstance(content, list):
|
||||
context_messages.extend(content)
|
||||
|
||||
return context_messages
|
||||
|
||||
async def _sync_response_to_conversation(
|
||||
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
|
||||
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
conversation_items = []
|
||||
|
||||
# add user content message(s)
|
||||
if isinstance(content, str):
|
||||
if isinstance(input, str):
|
||||
conversation_items.append(
|
||||
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
OpenAIResponseMessage(role="user", content=[OpenAIResponseInputMessageContentText(text=input)])
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if not isinstance(item, OpenAIResponseMessage):
|
||||
raise NotImplementedError(f"Unsupported input item type: {type(item)}")
|
||||
elif isinstance(input, list):
|
||||
conversation_items.extend(input)
|
||||
|
||||
if item.role == "user":
|
||||
if isinstance(item.content, str):
|
||||
conversation_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": item.content}],
|
||||
}
|
||||
)
|
||||
elif isinstance(item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "user", "content": item.content})
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported user message content type: {type(item.content)}")
|
||||
elif item.role == "assistant":
|
||||
if isinstance(item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "assistant", "content": item.content})
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported assistant message content type: {type(item.content)}")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message role: {item.role}")
|
||||
conversation_items.extend(output_items)
|
||||
|
||||
# add assistant response message
|
||||
for output_item in response.output:
|
||||
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
|
||||
if hasattr(output_item, "content") and isinstance(output_item.content, list):
|
||||
conversation_items.append({"type": "message", "role": "assistant", "content": output_item.content})
|
||||
|
||||
if conversation_items:
|
||||
adapter = TypeAdapter(list[ConversationItem])
|
||||
validated_items = adapter.validate_python(conversation_items)
|
||||
await self.conversations_api.add_items(conversation_id, validated_items)
|
||||
adapter = TypeAdapter(list[ConversationItem])
|
||||
validated_items = adapter.validate_python(conversation_items)
|
||||
await self.conversations_api.add_items(conversation_id, validated_items)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
|
@ -42,8 +43,12 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseRefusalDelta,
|
||||
OpenAIResponseObjectStreamResponseRefusalDone,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseUsage,
|
||||
OpenAIResponseUsageInputTokensDetails,
|
||||
|
|
@ -61,10 +66,15 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
from .utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
is_function_tool_call,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
|
@ -100,6 +110,9 @@ class StreamingResponseOrchestrator:
|
|||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
instructions: str,
|
||||
safety_api,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
|
@ -108,6 +121,8 @@ class StreamingResponseOrchestrator:
|
|||
self.text = text
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.safety_api = safety_api
|
||||
self.guardrail_ids = guardrail_ids or []
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
|
|
@ -117,6 +132,25 @@ class StreamingResponseOrchestrator:
|
|||
self.citation_files: dict[str, str] = {}
|
||||
# Track accumulated usage across all inference calls
|
||||
self.accumulated_usage: OpenAIResponseUsage | None = None
|
||||
# Track if we've sent a refusal response
|
||||
self.violation_detected = False
|
||||
# system message that is inserted into the model's context
|
||||
self.instructions = instructions
|
||||
|
||||
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
|
||||
"""Create a refusal response to replace streaming content."""
|
||||
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
|
||||
|
||||
# Create a completed refusal response
|
||||
refusal_response = OpenAIResponseObject(
|
||||
id=self.response_id,
|
||||
created_at=self.created_at,
|
||||
model=self.ctx.model,
|
||||
status="completed",
|
||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
||||
)
|
||||
|
||||
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
||||
|
||||
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
|
||||
cloned: list[OpenAIResponseOutput] = []
|
||||
|
|
@ -145,6 +179,7 @@ class StreamingResponseOrchestrator:
|
|||
tools=self.ctx.available_tools(),
|
||||
error=error,
|
||||
usage=self.accumulated_usage,
|
||||
instructions=self.instructions,
|
||||
)
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
|
|
@ -161,6 +196,15 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Input safety validation - check messages before processing
|
||||
if self.guardrail_ids:
|
||||
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
|
||||
input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
|
||||
if input_violation_message:
|
||||
logger.info(f"Input guardrail violation: {input_violation_message}")
|
||||
yield await self._create_refusal_response(input_violation_message)
|
||||
return
|
||||
|
||||
async for stream_event in self._process_tools(output_messages):
|
||||
yield stream_event
|
||||
|
||||
|
|
@ -175,6 +219,7 @@ class StreamingResponseOrchestrator:
|
|||
# (some providers don't support non-empty response_format when tools are present)
|
||||
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
|
||||
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
|
||||
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
|
|
@ -195,6 +240,11 @@ class StreamingResponseOrchestrator:
|
|||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
|
||||
# If violation detected, skip the rest of processing since we already sent refusal
|
||||
if self.violation_detected:
|
||||
return
|
||||
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
last_completion_result = completion_result_data
|
||||
|
|
@ -500,6 +550,7 @@ class StreamingResponseOrchestrator:
|
|||
# Track tool call items for streaming events
|
||||
tool_call_item_ids: dict[int, str] = {}
|
||||
# Track content parts for streaming events
|
||||
message_item_added_emitted = False
|
||||
content_part_emitted = False
|
||||
reasoning_part_emitted = False
|
||||
refusal_part_emitted = False
|
||||
|
|
@ -518,9 +569,29 @@ class StreamingResponseOrchestrator:
|
|||
# Accumulate usage from chunks (typically in final chunk with stream_options)
|
||||
self._accumulate_chunk_usage(chunk)
|
||||
|
||||
# Track deltas for this specific chunk for guardrail validation
|
||||
chunk_events: list[OpenAIResponseObjectStream] = []
|
||||
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
# Emit output_item.added for the message on first content
|
||||
if not message_item_added_emitted:
|
||||
message_item_added_emitted = True
|
||||
self.sequence_number += 1
|
||||
message_item = OpenAIResponseMessage(
|
||||
id=message_item_id,
|
||||
content=[],
|
||||
role="assistant",
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=message_item,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit content_part.added event for first text chunk
|
||||
if not content_part_emitted:
|
||||
content_part_emitted = True
|
||||
|
|
@ -536,13 +607,19 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
|
||||
text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=content_index,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Buffer text delta events for guardrail check
|
||||
if self.guardrail_ids:
|
||||
chunk_events.append(text_delta_event)
|
||||
else:
|
||||
yield text_delta_event
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
|
|
@ -558,7 +635,11 @@ class StreamingResponseOrchestrator:
|
|||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
# Buffer reasoning events for guardrail check
|
||||
if self.guardrail_ids:
|
||||
chunk_events.append(event)
|
||||
else:
|
||||
yield event
|
||||
reasoning_part_emitted = True
|
||||
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
|
||||
|
||||
|
|
@ -593,19 +674,22 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Emit output_item.added event for the new function call
|
||||
self.sequence_number += 1
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments="", # Will be filled incrementally via delta events
|
||||
call_id=tool_call.id or "",
|
||||
name=tool_call.function.name if tool_call.function else "",
|
||||
id=tool_call_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
|
||||
if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
|
||||
# for MCP tools (and even other non-function tools) we emit an output message item later
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments="", # Will be filled incrementally via delta events
|
||||
call_id=tool_call.id or "",
|
||||
name=tool_call.function.name if tool_call.function else "",
|
||||
id=tool_call_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
|
|
@ -637,6 +721,22 @@ class StreamingResponseOrchestrator:
|
|||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Output Safety Validation for this chunk
|
||||
if self.guardrail_ids:
|
||||
# Check guardrails on accumulated text so far
|
||||
accumulated_text = "".join(chat_response_content)
|
||||
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
|
||||
if violation_message:
|
||||
logger.info(f"Output guardrail violation: {violation_message}")
|
||||
chunk_events.clear()
|
||||
yield await self._create_refusal_response(violation_message)
|
||||
self.violation_detected = True
|
||||
return
|
||||
else:
|
||||
# No violation detected, emit all content events for this chunk
|
||||
for event in chunk_events:
|
||||
yield event
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call = chat_response_tool_calls[tool_call_index]
|
||||
|
|
@ -700,6 +800,32 @@ class StreamingResponseOrchestrator:
|
|||
if chat_response_tool_calls:
|
||||
chat_response_content = []
|
||||
|
||||
# Emit output_item.done for message when we have content and no tool calls
|
||||
if message_item_added_emitted and not chat_response_tool_calls:
|
||||
content_parts = []
|
||||
if content_part_emitted:
|
||||
final_text = "".join(chat_response_content)
|
||||
content_parts.append(
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text=final_text,
|
||||
annotations=[],
|
||||
)
|
||||
)
|
||||
|
||||
self.sequence_number += 1
|
||||
message_item = OpenAIResponseMessage(
|
||||
id=message_item_id,
|
||||
content=content_parts,
|
||||
role="assistant",
|
||||
status="completed",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=message_item,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
yield ChatCompletionResult(
|
||||
response_id=chat_response_id,
|
||||
content=chat_response_content,
|
||||
|
|
@ -760,6 +886,36 @@ class StreamingResponseOrchestrator:
|
|||
if not matching_item_id:
|
||||
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||
|
||||
self.sequence_number += 1
|
||||
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
|
||||
item = OpenAIResponseOutputMessageMCPCall(
|
||||
arguments="",
|
||||
name=tool_call.function.name,
|
||||
id=matching_item_id,
|
||||
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
|
||||
status="in_progress",
|
||||
)
|
||||
elif tool_call.function.name == "web_search":
|
||||
item = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=matching_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
elif tool_call.function.name == "knowledge_search":
|
||||
item = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=matching_item_id,
|
||||
status="in_progress",
|
||||
queries=[tool_call.function.arguments or ""],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool call: {tool_call.function.name}")
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Execute tool call with streaming
|
||||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
|
|
@ -1018,7 +1174,11 @@ class StreamingResponseOrchestrator:
|
|||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
item=OpenAIResponseOutputMessageMCPListTools(
|
||||
id=mcp_list_message.id,
|
||||
server_label=mcp_list_message.server_label,
|
||||
tools=[],
|
||||
),
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class ToolExecutor:
|
|||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
|
|
@ -356,6 +356,7 @@ class ToolExecutor:
|
|||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
item_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
|
|
@ -375,7 +376,7 @@ class ToolExecutor:
|
|||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
|
|
@ -389,14 +390,14 @@ class ToolExecutor:
|
|||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
id=item_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
|
|
@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(
|
||||
|
|
@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
|
|||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
|
||||
"""Get the appropriate OpenAI message parameter type for a given role."""
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
|
|
@ -307,3 +311,55 @@ def is_function_tool_call(
|
|||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
"""Run guardrails against messages and return violation message if blocked."""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.routing_table.list_shields()
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
if matching_shields:
|
||||
model_id = matching_shields[0].provider_resource_id
|
||||
model_ids.append(model_id)
|
||||
else:
|
||||
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
|
||||
|
||||
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
|
||||
responses = await asyncio.gather(*guardrail_tasks)
|
||||
|
||||
for response in responses:
|
||||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
message += f" (flagged for: {', '.join(flagged_categories)})"
|
||||
if violation_type:
|
||||
message += f" (violation type: {', '.join(violation_type)})"
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||
if not guardrails:
|
||||
return []
|
||||
|
||||
guardrail_ids = []
|
||||
for guardrail in guardrails:
|
||||
if isinstance(guardrail, str):
|
||||
guardrail_ids.append(guardrail)
|
||||
elif isinstance(guardrail, ResponseGuardrailSpec):
|
||||
guardrail_ids.append(guardrail.type)
|
||||
else:
|
||||
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
|
||||
|
||||
return guardrail_ids
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreConfig = Field(
|
||||
kvstore: KVStoreReference = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
|
|
@ -33,8 +33,8 @@ class ReferenceBatchesImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="batches.db",
|
||||
),
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="batches",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,20 +7,17 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="datasetio::localfs",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,20 +7,17 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="eval",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
|
||||
|
||||
class LocalfsFilesImplConfig(BaseModel):
|
||||
storage_dir: str = Field(
|
||||
description="Directory to store uploaded files",
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(
|
||||
metadata_store: SqlStoreReference = Field(
|
||||
description="SQL store configuration for file metadata",
|
||||
)
|
||||
ttl_secs: int = 365 * 24 * 60 * 60 # 1 year
|
||||
|
|
@ -24,8 +24,8 @@ class LocalfsFilesImplConfig(BaseModel):
|
|||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="files_metadata.db",
|
||||
),
|
||||
"metadata_store": SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="files_metadata",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,13 +9,10 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class TelemetrySink(StrEnum):
|
||||
OTEL_TRACE = "otel_trace"
|
||||
OTEL_METRIC = "otel_metric"
|
||||
SQLITE = "sqlite"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
|
|
@ -30,12 +27,8 @@ class TelemetryConfig(BaseModel):
|
|||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: list[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
|
||||
)
|
||||
sqlite_db_path: str = Field(
|
||||
default_factory=lambda: (RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
|
||||
description="The path to the SQLite database to use for storing traces",
|
||||
default_factory=list,
|
||||
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, console)",
|
||||
)
|
||||
|
||||
@field_validator("sinks", mode="before")
|
||||
|
|
@ -43,13 +36,12 @@ class TelemetryConfig(BaseModel):
|
|||
def validate_sinks(cls, v):
|
||||
if isinstance(v, str):
|
||||
return [TelemetrySink(sink.strip()) for sink in v.split(",")]
|
||||
return v
|
||||
return v or []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:=sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
"sinks": "${env.TELEMETRY_SINKS:=}",
|
||||
"otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,75 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
def __init__(self, print_attributes: bool = False):
|
||||
self.print_attributes = print_attributes
|
||||
|
||||
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: float | None = None) -> bool:
|
||||
"""Force flush any pending spans."""
|
||||
return True
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
def __init__(self, conn_string):
|
||||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self._local = threading.local() # Thread-local storage for connections
|
||||
self.setup_database()
|
||||
|
||||
def _get_connection(self):
|
||||
"""Get a thread-local database connection."""
|
||||
if not hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn = sqlite3.connect(self.conn_string)
|
||||
except Exception as e:
|
||||
print(f"Error connecting to SQLite database: {e}")
|
||||
raise
|
||||
return self._local.conn
|
||||
|
||||
def setup_database(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS traces (
|
||||
trace_id TEXT PRIMARY KEY,
|
||||
service_name TEXT,
|
||||
root_span_id TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS spans (
|
||||
span_id TEXT PRIMARY KEY,
|
||||
trace_id TEXT REFERENCES traces(trace_id),
|
||||
parent_span_id TEXT,
|
||||
name TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
attributes TEXT,
|
||||
status TEXT,
|
||||
kind TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS span_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
span_id TEXT REFERENCES spans(span_id),
|
||||
name TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
attributes TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_traces_created_at
|
||||
ON traces(created_at)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
"""Called when a span starts."""
|
||||
pass
|
||||
|
||||
def on_end(self, span: Span):
|
||||
"""Called when a span ends. Export the span data to SQLite."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format_trace_id(span.get_span_context().trace_id)
|
||||
span_id = format_span_id(span.get_span_context().span_id)
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format_span_id(parent_context.span_id)
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO traces (
|
||||
trace_id, service_name, root_span_id, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(trace_id) DO UPDATE SET
|
||||
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
|
||||
start_time = MIN(excluded.start_time, start_time),
|
||||
end_time = MAX(excluded.end_time, end_time)
|
||||
""",
|
||||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if span.attributes.get(LOCAL_ROOT_SPAN_MARKER) else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
# Insert into spans
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO spans (
|
||||
span_id, trace_id, parent_span_id, name,
|
||||
start_time, end_time, attributes, status,
|
||||
kind
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
),
|
||||
)
|
||||
|
||||
for event in span.events:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO span_events (
|
||||
span_id, name, timestamp, attributes
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9, UTC).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(f"Error exporting span to SQLite: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
# We can't access other threads' connections, so we just close our own
|
||||
if hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing SQLite connection: {e}")
|
||||
finally:
|
||||
del self._local.conn
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
pass
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -13,43 +13,25 @@ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExp
|
|||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
MetricEvent,
|
||||
MetricLabelMatcher,
|
||||
MetricQueryType,
|
||||
QueryCondition,
|
||||
QueryMetricsResponse,
|
||||
QuerySpanTreeResponse,
|
||||
QueryTracesResponse,
|
||||
Span,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
Trace,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
from .config import TelemetryConfig
|
||||
|
||||
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||
"active_spans": {},
|
||||
|
|
@ -68,66 +50,49 @@ def is_tracing_enabled(tracer):
|
|||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
|
||||
self.config = config
|
||||
class TelemetryAdapter(Telemetry):
|
||||
def __init__(self, _config: TelemetryConfig, deps: dict[Api, Any]) -> None:
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
global _TRACER_PROVIDER
|
||||
# Initialize the correct span processor based on the provider state.
|
||||
# This is needed since once the span processor is set, it cannot be unset.
|
||||
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
|
||||
# Since the library client can be recreated multiple times in a notebook,
|
||||
# the kernel will hold on to the span processor and cause duplicate spans to be written.
|
||||
if _TRACER_PROVIDER is None:
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
|
||||
if _TRACER_PROVIDER is None:
|
||||
provider = TracerProvider()
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
|
||||
# Use single OTLP endpoint for all telemetry signals
|
||||
if TelemetrySink.OTEL_TRACE in self.config.sinks or TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
if self.config.otel_exporter_otlp_endpoint is None:
|
||||
raise ValueError(
|
||||
"otel_exporter_otlp_endpoint is required when OTEL_TRACE or OTEL_METRIC is enabled"
|
||||
)
|
||||
# Use single OTLP endpoint for all telemetry signals
|
||||
|
||||
# Let OpenTelemetry SDK handle endpoint construction automatically
|
||||
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
|
||||
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
|
||||
if TelemetrySink.OTEL_TRACE in self.config.sinks:
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.is_otel_endpoint_set = True
|
||||
else:
|
||||
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
|
||||
self.is_otel_endpoint_set = False
|
||||
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
if self.is_otel_endpoint_set:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
|
|
@ -139,47 +104,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
start_time: int,
|
||||
end_time: int | None = None,
|
||||
granularity: str | None = None,
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
"""Query metrics from the telemetry store.
|
||||
|
||||
Args:
|
||||
metric_name: The name of the metric to query (e.g., "prompt_tokens")
|
||||
start_time: Start time as Unix timestamp
|
||||
end_time: End time as Unix timestamp (defaults to now if None)
|
||||
granularity: Time granularity for aggregation
|
||||
query_type: Type of query (RANGE or INSTANT)
|
||||
label_matchers: Label filters to apply
|
||||
|
||||
Returns:
|
||||
QueryMetricsResponse with metric time series data
|
||||
"""
|
||||
# Convert timestamps to datetime objects
|
||||
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
|
||||
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
|
||||
|
||||
# Use SQLite trace store if available
|
||||
if hasattr(self, "trace_store") and self.trace_store:
|
||||
return await self.trace_store.query_metrics(
|
||||
metric_name=metric_name,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
granularity=granularity,
|
||||
query_type=query_type,
|
||||
label_matchers=label_matchers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
|
||||
)
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
|
|
@ -326,39 +250,3 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> QueryTracesResponse:
|
||||
return QueryTracesResponse(
|
||||
data=await self.trace_store.query_traces(
|
||||
attribute_filters=attribute_filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
return await self.trace_store.get_trace(trace_id)
|
||||
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
return await self.trace_store.get_span(trace_id, span_id)
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpanTreeResponse:
|
||||
return QuerySpanTreeResponse(
|
||||
data=await self.trace_store.get_span_tree(
|
||||
span_id=span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
return RAGQueryResult(
|
||||
content=picked,
|
||||
metadata={
|
||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||
"document_ids": [c.document_id for c in chunks[: len(picked)]],
|
||||
"chunks": [c.content for c in chunks[: len(picked)]],
|
||||
"scores": scores[: len(picked)],
|
||||
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ from .config import ChromaVectorIOConfig
|
|||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
|
||||
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
@ -23,8 +23,8 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": db_path,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_inline_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::chroma",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
|||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,22 +8,19 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FaissVectorIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
persistence: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="faiss_store.db",
|
||||
)
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::faiss",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,33 +17,21 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
|
|
@ -154,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
@ -174,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
|
@ -198,28 +176,28 @@ class FaissIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
# Load existing banks from kvstore
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -244,45 +222,33 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=vector_db.model_dump_json(),
|
||||
)
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [i.vector_db for i in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [i.vector_store for i in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
|
||||
|
|
@ -290,10 +256,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,6 @@ from .config import MilvusVectorIOConfig
|
|||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,25 +8,22 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::milvus",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
|
|||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
files_api = deps.get(Api.files)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -9,23 +9,21 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QdrantVectorIOConfig(BaseModel):
|
||||
path: str
|
||||
kvstore: KVStoreConfig
|
||||
persistence: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::qdrant",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
|||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,22 +8,19 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class SQLiteVectorIOConfig(BaseModel):
|
||||
db_path: str = Field(description="Path to the SQLite database file")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::sqlite_vec",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,14 +17,10 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -32,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
|
|
@ -45,7 +41,7 @@ HYBRID_SEARCH = "hybrid"
|
|||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:sqlite_vec:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
|
||||
|
|
@ -174,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
# Insert vector embeddings
|
||||
embedding_data = [
|
||||
(
|
||||
(
|
||||
chunk.chunk_id,
|
||||
serialize_vector(emb.tolist()),
|
||||
)
|
||||
)
|
||||
((chunk.chunk_id, serialize_vector(emb.tolist())))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
||||
# Insert FTS content
|
||||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
|
||||
|
||||
connection.commit()
|
||||
|
||||
|
|
@ -215,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Run batch insertion in a background thread
|
||||
await asyncio.to_thread(_execute_all_batch_inserts)
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
|
|
@ -260,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
|
|
@ -402,33 +374,32 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""
|
||||
A VectorIO implementation using SQLite + sqlite_vec.
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
||||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_stores`)
|
||||
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.vector_store_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -437,67 +408,64 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [v.vector_store for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=SQLiteVecIndex(
|
||||
dimension=vector_db.embedding_dimension,
|
||||
dimension=vector_store.embedding_dimension,
|
||||
db_path=self.config.db_path,
|
||||
bank_id=vector_db.identifier,
|
||||
bank_id=vector_store.identifier,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index's add_chunks.
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -36,9 +36,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
Api.tool_groups,
|
||||
Api.conversations,
|
||||
],
|
||||
optional_api_dependencies=[
|
||||
Api.telemetry,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.telemetry,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
],
|
||||
optional_api_dependencies=[Api.datasetio],
|
||||
module="llama_stack.providers.inline.telemetry.meta_reference",
|
||||
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
|
||||
description="Meta's reference implementation of telemetry and observability using OpenTelemetry.",
|
||||
),
|
||||
]
|
||||
|
|
@ -26,7 +26,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="Meta's reference implementation of a vector database.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
|
|
@ -36,7 +36,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.vector_io.faiss",
|
||||
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
|
|
@ -89,7 +89,7 @@ more details about Faiss in general.
|
|||
module="llama_stack.providers.inline.vector_io.sqlite_vec",
|
||||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within an SQLite database.
|
||||
|
|
@ -297,7 +297,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
|
||||
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
Please refer to the sqlite-vec provider documentation.
|
||||
""",
|
||||
|
|
@ -310,7 +310,7 @@ Please refer to the sqlite-vec provider documentation.
|
|||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
|
|
@ -352,7 +352,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
|
|
@ -396,7 +396,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
|
|
@ -508,7 +508,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
|
|
@ -548,7 +548,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
|
|||
module="llama_stack.providers.inline.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description=r"""
|
||||
[Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
|
|
@ -601,7 +601,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
|
|||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
|
|
@ -614,7 +614,7 @@ Please refer to the inline provider documentation.
|
|||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
|
|
@ -820,7 +820,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
Please refer to the remote provider documentation.
|
||||
""",
|
||||
|
|
|
|||
|
|
@ -7,20 +7,17 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
"kvstore": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="datasetio::huggingface",
|
||||
).model_dump(exclude_none=True)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service.
|
|||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
|
||||
|
||||
class S3FilesImplConfig(BaseModel):
|
||||
|
|
@ -24,7 +24,7 @@ class S3FilesImplConfig(BaseModel):
|
|||
auto_create_bucket: bool = Field(
|
||||
default=False, description="Automatically create the S3 bucket if it doesn't exist"
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
|
||||
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
|
|
@ -35,8 +35,8 @@ class S3FilesImplConfig(BaseModel):
|
|||
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
|
||||
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
|
||||
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="s3_files_metadata.db",
|
||||
),
|
||||
"metadata_store": SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="s3_files_metadata",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
|
@ -14,8 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
|
|||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Override embeddings method to handle Gemini's missing usage statistics.
|
||||
Gemini's embedding API doesn't return usage information, so we provide default values.
|
||||
"""
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": await self._get_provider_model_id(params.model),
|
||||
"input": params.input,
|
||||
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
||||
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
||||
"user": params.user if params.user is not None else NOT_GIVEN,
|
||||
}
|
||||
|
||||
# Add extra_body if present
|
||||
extra_body = params.model_extra
|
||||
if extra_body:
|
||||
request_params["extra_body"] = extra_body
|
||||
|
||||
# Call OpenAI embeddings API with properly typed parameters
|
||||
response = await self.client.embeddings.create(**request_params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
# Gemini doesn't return usage statistics - use default values
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
else:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=params.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ This provider enables running inference using NVIDIA NIM.
|
|||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
|
@ -45,7 +45,7 @@ The following example shows how to create a chat completion for an NVIDIA NIM.
|
|||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -67,37 +67,40 @@ print(f"Response: {response.choices[0].message.content}")
|
|||
The following example shows how to do tool calling for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
tool_definition = ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get current weather information for a location",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
required=True,
|
||||
),
|
||||
"unit": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="Temperature unit (celsius or fahrenheit)",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
tool_definition = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather information for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "Temperature unit (celsius or fahrenheit)",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
tool_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.choices[0].message.content}")
|
||||
print(f"Response content: {tool_response.choices[0].message.content}")
|
||||
if tool_response.choices[0].message.tool_calls:
|
||||
for tool_call in tool_response.choices[0].message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
print(f"Tool Called: {tool_call.function.name}")
|
||||
print(f"Arguments: {tool_call.function.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
|
|
@ -105,33 +108,26 @@ if tool_response.choices[0].message.tool_calls:
|
|||
The following example shows how to do structured output for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"age": {"type": "number"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
response_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
response_format=response_format,
|
||||
extra_body={"nvext": {"guided_json": person_schema}},
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
|
|
@ -139,16 +135,13 @@ print(f"Structured Response: {structured_response.choices[0].message.content}")
|
|||
|
||||
The following example shows how to create embeddings for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
contents=["What is the capital of France?"],
|
||||
task_type="query",
|
||||
response = client.embeddings.create(
|
||||
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
input=["What is the capital of France?"],
|
||||
extra_body={"input_type": "query"},
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
print(f"Embeddings: {response.data}")
|
||||
```
|
||||
|
||||
### Vision Language Models Example
|
||||
|
|
@ -166,15 +159,15 @@ image_path = {path_to_the_image}
|
|||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.chat.completions.create(
|
||||
model="nvidia/vila",
|
||||
model="nvidia/meta/llama-3.2-11b-vision-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": demo_image_b64,
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{demo_image_b64}",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from .config import NVIDIAConfig
|
|||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
|
||||
from .nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
|
|
|
|||
|
|
@ -5,14 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -27,15 +19,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of ModelRegistryHelper.check_model_availability(). It also
|
||||
must come before Inference to ensure that OpenAIMixin methods are available
|
||||
in the Inference interface.
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
|
||||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
|
|
@ -76,50 +59,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
OpenAI-compatible embeddings for NVIDIA NIM.
|
||||
|
||||
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
|
||||
We default this to "query" to ensure requests succeed when using the
|
||||
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
|
||||
`task_type='document'`.
|
||||
"""
|
||||
extra_body: dict[str, object] = {"input_type": "query"}
|
||||
logger.warning(
|
||||
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
|
||||
"For passage embeddings, use the embeddings API with task_type='document'."
|
||||
)
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(params.model),
|
||||
input=params.input,
|
||||
encoding_format=params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
||||
dimensions=params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
||||
user=params.user if params.user is not None else NOT_GIVEN,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service
|
|||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ This provider enables safety checks and guardrails for LLM interactions using NV
|
|||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
|
|
|||
|
|
@ -13,26 +13,19 @@ import chromadb
|
|||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||
|
|
@ -40,7 +33,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
|
|||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
||||
|
|
@ -70,19 +63,13 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
|
||||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
|
||||
)
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
|
|
@ -160,6 +147,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a single chunk from the Chroma collection by its ID."""
|
||||
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
|
||||
|
|
@ -227,11 +215,11 @@ class ChromaIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
|
@ -240,11 +228,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.vector_db_store = self.kvstore
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
self.vector_store_table = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
|
|
@ -264,70 +252,58 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db, ChromaIndex(self.client, collection), self.inference_api
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_store_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
|
||||
collection = await maybe_await(self.client.get_collection(vector_db_id))
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
|
||||
collection = await maybe_await(self.client.get_collection(vector_store_id))
|
||||
if not collection:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
|
||||
self.cache[vector_db_id] = index
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
|
||||
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Chroma vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -8,21 +8,21 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
url: str | None
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_remote_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::chroma_remote",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
|
|||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
uri: str = Field(description="The URI of the Milvus server")
|
||||
token: str | None = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
persistence: KVStoreReference = Field(description="Config for KV store backend")
|
||||
|
||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||
|
|
@ -28,8 +28,8 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
return {
|
||||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_remote_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::milvus_remote",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,16 +12,12 @@ from numpy.typing import NDArray
|
|||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -30,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
|
||||
|
||||
|
|
@ -39,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
|||
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"
|
||||
|
|
@ -73,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
|
|||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
|
||||
schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
|
|
@ -143,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
|
@ -166,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
|
|
@ -209,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
|
|
@ -302,7 +261,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
|
|
@ -314,28 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
collection_name=vector_store.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
|
|
@ -352,72 +311,61 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a milvus vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
|
|||
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorVectorIOAdapter
|
||||
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -22,7 +19,9 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
db: str | None = Field(default="postgres")
|
||||
user: str | None = Field(default="postgres")
|
||||
password: str | None = Field(default="mysecretpassword")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
persistence: KVStoreReference | None = Field(
|
||||
description="Config for KV store backend (SQLite only for now)", default=None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
@ -41,8 +40,8 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
"db": db,
|
||||
"user": user,
|
||||
"password": password,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="pgvector_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::pgvector",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,27 +14,17 @@ from psycopg2.extras import Json, execute_values
|
|||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
|
@ -42,7 +32,7 @@ from .config import PGVectorVectorIOConfig
|
|||
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:pgvector:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||
|
|
@ -89,13 +79,13 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
vector_store: VectorStore,
|
||||
dimension: int,
|
||||
conn: psycopg2.extensions.connection,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.vector_store = vector_store
|
||||
self.dimension = dimension
|
||||
self.conn = conn
|
||||
self.kvstore = kvstore
|
||||
|
|
@ -107,9 +97,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# SQL doesn't allow hyphens in table names, and vector_store.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_store.identifier)
|
||||
self.table_name = f"vs_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
|
|
@ -132,8 +122,8 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
log.exception(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
|
@ -204,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
|
|
@ -316,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""Remove a chunk from the PostgreSQL table."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
|
@ -338,24 +323,21 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: PGVectorVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
files_api: Files | None = None,
|
||||
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
try:
|
||||
|
|
@ -393,71 +375,59 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
assert self.kvstore is not None
|
||||
# Upsert model metadata in Postgres
|
||||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
upsert_models(self.conn, [(vector_store.identifier, vector_store)])
|
||||
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
pgvector_index = PGVectorIndex(
|
||||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
# Remove provider index and cache
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
# Delete vector DB metadata from KV store
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a PostgreSQL vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from .config import QdrantVectorIOConfig
|
|||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .qdrant import QdrantVectorIOAdapter
|
||||
|
||||
files_api = deps.get(Api.files)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -27,14 +24,14 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
prefix: str | None = None
|
||||
timeout: int | None = None
|
||||
host: str | None = None
|
||||
kvstore: KVStoreConfig
|
||||
persistence: KVStoreReference
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.QDRANT_API_KEY:=}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="qdrant_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::qdrant_remote",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from qdrant_client.models import PointStruct
|
|||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
|
|
@ -24,16 +23,13 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategy,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
|
|
@ -42,7 +38,7 @@ CHUNK_ID_KEY = "_chunk_id"
|
|||
|
||||
# KV store prefixes for vector databases
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:qdrant:{VERSION}::"
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
|
|
@ -98,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
|
||||
try:
|
||||
await self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
|
|
@ -132,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
|
||||
async def query_hybrid(
|
||||
|
|
@ -155,11 +145,11 @@ class QdrantIndex(EmbeddingIndex):
|
|||
await self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
|
||||
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None = None,
|
||||
) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
|
|
@ -167,26 +157,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"})
|
||||
client_config = self.config.model_dump(exclude_none=True, exclude={"persistence"})
|
||||
self.client = AsyncQdrantClient(**client_config)
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
QdrantIndex(self.client, vector_db.identifier),
|
||||
self.inference_api,
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -194,68 +182,57 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
assert self.kvstore is not None
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(self.client, vector_db.identifier),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=QdrantIndex(self.client, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise ValueError(f"Vector DB not found {vector_db_id}")
|
||||
if self.vector_store_table is None:
|
||||
raise ValueError(f"Vector DB not found {vector_store_id}")
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -276,7 +253,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Qdrant vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,6 @@ from .config import WeaviateVectorIOConfig
|
|||
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .weaviate import WeaviateVectorIOAdapter
|
||||
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -19,19 +16,17 @@ from llama_stack.schema_utils import json_schema_type
|
|||
class WeaviateVectorIOConfig(BaseModel):
|
||||
weaviate_api_key: str | None = Field(description="The API key for the Weaviate instance", default=None)
|
||||
weaviate_cluster_url: str | None = Field(description="The URL of the Weaviate cluster", default="localhost:8080")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
persistence: KVStoreReference | None = Field(
|
||||
description="Config for KV store backend (SQLite only for now)", default=None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"weaviate_api_key": None,
|
||||
"weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="weaviate_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::weaviate",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,22 +14,21 @@ from weaviate.classes.query import Filter, HybridFusion
|
|||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
|
||||
|
||||
|
|
@ -38,7 +37,7 @@ from .config import WeaviateVectorIOConfig
|
|||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:weaviate:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||
|
|
@ -46,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self,
|
||||
client: weaviate.WeaviateClient,
|
||||
collection_name: str,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
|
||||
self.kvstore = kvstore
|
||||
|
|
@ -106,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
try:
|
||||
results = collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(distance=True),
|
||||
near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client vector search failed: {e}")
|
||||
|
|
@ -151,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs BM25-based keyword search using Weaviate's built-in full-text search.
|
||||
Args:
|
||||
|
|
@ -173,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
# Perform BM25 keyword search on chunk_content field
|
||||
try:
|
||||
results = collection.query.bm25(
|
||||
query=query_string,
|
||||
limit=k,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Weaviate client keyword search failed: {e}")
|
||||
|
|
@ -272,24 +257,14 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
OpenAIVectorStoreMixin,
|
||||
VectorIO,
|
||||
NeedsRequestProviderData,
|
||||
VectorDBsProtocolPrivate,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: WeaviateVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
def _get_client(self) -> weaviate.WeaviateClient:
|
||||
|
|
@ -297,10 +272,7 @@ class WeaviateVectorIOAdapter(
|
|||
log.info("Using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
client = weaviate.connect_to_local(host=host, port=port)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
|
|
@ -316,8 +288,8 @@ class WeaviateVectorIOAdapter(
|
|||
async def initialize(self) -> None:
|
||||
"""Set up KV store and load existing vector DBs and OpenAI vector stores."""
|
||||
# Initialize KV store for metadata if configured
|
||||
if self.config.kvstore is not None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
if self.config.persistence is not None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
else:
|
||||
self.kvstore = None
|
||||
log.info("No kvstore configured, registry will not persist across restarts")
|
||||
|
|
@ -328,17 +300,11 @@ class WeaviateVectorIOAdapter(
|
|||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for raw in stored:
|
||||
vector_db = VectorDB.model_validate_json(raw)
|
||||
vector_store = VectorStore.model_validate_json(raw)
|
||||
client = self._get_client()
|
||||
idx = WeaviateIndex(
|
||||
client=client,
|
||||
collection_name=vector_db.identifier,
|
||||
kvstore=self.kvstore,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=idx,
|
||||
inference_api=self.inference_api,
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store, index=idx, inference_api=self.inference_api
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
|
|
@ -350,90 +316,74 @@ class WeaviateVectorIOAdapter(
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(sanitized_collection_name):
|
||||
client.collections.create(
|
||||
name=sanitized_collection_name,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
|
||||
],
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db,
|
||||
WeaviateIndex(client=client, collection_name=sanitized_collection_name),
|
||||
self.inference_api,
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store_id, weaviate_format=True)
|
||||
if vector_store_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
|
||||
if not client.collections.exists(sanitized_collection_name):
|
||||
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,12 @@
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import platform
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -24,6 +27,8 @@ from llama_stack.apis.inference import (
|
|||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
DARWIN = "Darwin"
|
||||
|
||||
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
|
@ -83,6 +88,13 @@ class SentenceTransformerEmbeddingMixin:
|
|||
def _load_model():
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
platform_name = platform.system()
|
||||
if platform_name == DARWIN:
|
||||
# PyTorch's OpenMP kernels can segfault on macOS when spawned from background
|
||||
# threads with the default parallel settings, so force a single-threaded CPU run.
|
||||
log.debug(f"Constraining torch threads on {platform_name} to a single worker")
|
||||
torch.set_num_threads(1)
|
||||
|
||||
return SentenceTransformer(model, trust_remote_code=True)
|
||||
|
||||
loaded_model = await asyncio.to_thread(_load_model)
|
||||
|
|
|
|||
|
|
@ -15,12 +15,13 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
|
@ -28,33 +29,32 @@ logger = get_logger(name=__name__, category="inference")
|
|||
class InferenceStore:
|
||||
def __init__(
|
||||
self,
|
||||
config: InferenceStoreConfig | SqlStoreConfig,
|
||||
reference: InferenceStoreReference,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
# Handle backward compatibility
|
||||
if not isinstance(config, InferenceStoreConfig):
|
||||
# Legacy: SqlStoreConfig passed directly as config
|
||||
config = InferenceStoreConfig(
|
||||
sql_store_config=config,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.sql_store_config = config.sql_store_config
|
||||
self.reference = reference
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
self._max_write_queue_size: int = reference.max_write_queue_size
|
||||
self._num_writers: int = max(1, reference.num_writers)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
backend_name = self.reference.backend
|
||||
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
|
|
|||
|
|
@ -168,13 +168,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
is used instead of any config API key.
|
||||
"""
|
||||
|
||||
api_key = self.get_api_key()
|
||||
|
||||
if self.provider_data_api_key_field:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
api_key = self._get_api_key_from_config_or_provider_data()
|
||||
if not api_key:
|
||||
message = "API key not provided."
|
||||
if self.provider_data_api_key_field:
|
||||
|
|
@ -187,6 +181,16 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
**self.get_extra_client_params(),
|
||||
)
|
||||
|
||||
def _get_api_key_from_config_or_provider_data(self) -> str | None:
|
||||
api_key = self.get_api_key()
|
||||
|
||||
if self.provider_data_api_key_field:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
|
||||
api_key = getattr(provider_data, self.provider_data_api_key_field)
|
||||
|
||||
return api_key
|
||||
|
||||
async def _get_provider_model_id(self, model: str) -> str:
|
||||
"""
|
||||
Get the provider-specific model ID from the model store.
|
||||
|
|
@ -387,6 +391,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
"""
|
||||
self._model_cache = {}
|
||||
|
||||
api_key = self._get_api_key_from_config_or_provider_data()
|
||||
if not api_key:
|
||||
logger.debug(f"{self.__class__.__name__}.list_provider_model_ids() disabled because API key not provided")
|
||||
return None
|
||||
|
||||
try:
|
||||
iterable = await self.list_provider_model_ids()
|
||||
except Exception as e:
|
||||
|
|
@ -435,7 +444,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
"""
|
||||
# First check if the model is pre-registered in the model store
|
||||
if hasattr(self, "model_store") and self.model_store:
|
||||
if await self.model_store.has_model(model):
|
||||
qualified_model = f"{self.__provider_id__}/{model}" # type: ignore[attr-defined]
|
||||
if await self.model_store.has_model(qualified_model):
|
||||
return True
|
||||
|
||||
# Then check the provider's dynamic model cache
|
||||
|
|
|
|||
|
|
@ -4,143 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class KVStoreType(Enum):
|
||||
redis = "redis"
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
mongodb = "mongodb"
|
||||
|
||||
|
||||
class CommonConfig(BaseModel):
|
||||
namespace: str | None = Field(
|
||||
default=None,
|
||||
description="All keys will be prefixed with this namespace",
|
||||
)
|
||||
|
||||
|
||||
class RedisKVStoreConfig(CommonConfig):
|
||||
type: Literal["redis"] = KVStoreType.redis.value
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["redis"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"type": "redis",
|
||||
"host": "${env.REDIS_HOST:=localhost}",
|
||||
"port": "${env.REDIS_PORT:=6379}",
|
||||
}
|
||||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal["sqlite"] = KVStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["aiosqlite"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
||||
return {
|
||||
"type": "sqlite",
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
||||
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal["postgres"] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
ssl_mode: str | None = None
|
||||
ca_cert_path: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
|
||||
return {
|
||||
"type": "postgres",
|
||||
"host": "${env.POSTGRES_HOST:=localhost}",
|
||||
"port": "${env.POSTGRES_PORT:=5432}",
|
||||
"db": "${env.POSTGRES_DB:=llamastack}",
|
||||
"user": "${env.POSTGRES_USER:=llamastack}",
|
||||
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
|
||||
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@field_validator("table_name")
|
||||
def validate_table_name(cls, v: str) -> str:
|
||||
# PostgreSQL identifiers rules:
|
||||
# - Must start with a letter or underscore
|
||||
# - Can contain letters, numbers, and underscores
|
||||
# - Maximum length is 63 bytes
|
||||
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError(
|
||||
"Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
|
||||
)
|
||||
if len(v) > 63:
|
||||
raise ValueError("Table name must be less than 63 characters")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["psycopg2-binary"]
|
||||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal["mongodb"] = KVStoreType.mongodb.value
|
||||
host: str = "localhost"
|
||||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["pymongo"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
"type": "mongodb",
|
||||
"host": "${env.MONGODB_HOST:=localhost}",
|
||||
"port": "${env.MONGODB_PORT:=5432}",
|
||||
"db": "${env.MONGODB_DB}",
|
||||
"user": "${env.MONGODB_USER}",
|
||||
"password": "${env.MONGODB_PASSWORD}",
|
||||
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
|
||||
}
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
MongoDBKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
StorageBackendType,
|
||||
)
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig,
|
||||
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -148,13 +25,13 @@ def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
|
|||
"""Get pip packages for KV store config, handling both dict and object cases."""
|
||||
if isinstance(store_config, dict):
|
||||
store_type = store_config.get("type")
|
||||
if store_type == "sqlite":
|
||||
if store_type == StorageBackendType.KV_SQLITE.value:
|
||||
return SqliteKVStoreConfig.pip_packages()
|
||||
elif store_type == "postgres":
|
||||
elif store_type == StorageBackendType.KV_POSTGRES.value:
|
||||
return PostgresKVStoreConfig.pip_packages()
|
||||
elif store_type == "redis":
|
||||
elif store_type == StorageBackendType.KV_REDIS.value:
|
||||
return RedisKVStoreConfig.pip_packages()
|
||||
elif store_type == "mongodb":
|
||||
elif store_type == StorageBackendType.KV_MONGODB.value:
|
||||
return MongoDBKVStoreConfig.pip_packages()
|
||||
else:
|
||||
raise ValueError(f"Unknown KV store type: {store_type}")
|
||||
|
|
|
|||
|
|
@ -4,9 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
|
||||
|
||||
from .api import KVStore
|
||||
from .config import KVStoreConfig, KVStoreType
|
||||
from .config import KVStoreConfig
|
||||
|
||||
|
||||
def kvstore_dependencies():
|
||||
|
|
@ -44,20 +52,41 @@ class InmemoryKVStoreImpl(KVStore):
|
|||
del self._store[key]
|
||||
|
||||
|
||||
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||
if config.type == KVStoreType.redis.value:
|
||||
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
|
||||
|
||||
|
||||
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
|
||||
"""Register the set of available KV store backends for reference resolution."""
|
||||
global _KVSTORE_BACKENDS
|
||||
|
||||
_KVSTORE_BACKENDS.clear()
|
||||
for name, cfg in backends.items():
|
||||
_KVSTORE_BACKENDS[name] = cfg
|
||||
|
||||
|
||||
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
|
||||
backend_name = reference.backend
|
||||
|
||||
backend_config = _KVSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
|
||||
|
||||
config = backend_config.model_copy()
|
||||
config.namespace = reference.namespace
|
||||
|
||||
if config.type == StorageBackendType.KV_REDIS.value:
|
||||
from .redis import RedisKVStoreImpl
|
||||
|
||||
impl = RedisKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.sqlite.value:
|
||||
elif config.type == StorageBackendType.KV_SQLITE.value:
|
||||
from .sqlite import SqliteKVStoreImpl
|
||||
|
||||
impl = SqliteKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.postgres.value:
|
||||
elif config.type == StorageBackendType.KV_POSTGRES.value:
|
||||
from .postgres import PostgresKVStoreImpl
|
||||
|
||||
impl = PostgresKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.mongodb.value:
|
||||
elif config.type == StorageBackendType.KV_MONGODB.value:
|
||||
from .mongodb import MongoDBKVStoreImpl
|
||||
|
||||
impl = MongoDBKVStoreImpl(config)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from pydantic import TypeAdapter
|
|||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFileObject
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
|
|
@ -43,6 +42,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -52,6 +52,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
# Constants for OpenAI vector stores
|
||||
|
|
@ -61,7 +63,7 @@ MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within
|
|||
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||
|
|
@ -77,7 +79,11 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# Implementing classes should call super().__init__() in their __init__ method
|
||||
# to properly initialize the mixin attributes.
|
||||
def __init__(self, files_api: Files | None = None, kvstore: KVStore | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
files_api: Files | None = None,
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.openai_file_batches: dict[str, dict[str, Any]] = {}
|
||||
self.files_api = files_api
|
||||
|
|
@ -315,12 +321,12 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register a vector database (provider-specific implementation)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Unregister a vector database (provider-specific implementation)."""
|
||||
pass
|
||||
|
||||
|
|
@ -349,35 +355,53 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
created_at = int(time.time())
|
||||
|
||||
# Extract llama-stack-specific parameters from extra_body
|
||||
extra = params.model_extra or {}
|
||||
provider_vector_db_id = extra.get("provider_vector_db_id")
|
||||
embedding_model = extra.get("embedding_model")
|
||||
embedding_dimension = extra.get("embedding_dimension", 768)
|
||||
extra_body = params.model_extra or {}
|
||||
metadata = params.metadata or {}
|
||||
|
||||
provider_vector_store_id = extra_body.get("provider_vector_store_id")
|
||||
|
||||
# Use embedding info from metadata if available, otherwise from extra_body
|
||||
if metadata.get("embedding_model"):
|
||||
# If either is in metadata, use metadata as source
|
||||
embedding_model = metadata.get("embedding_model")
|
||||
embedding_dimension = (
|
||||
int(metadata["embedding_dimension"]) if metadata.get("embedding_dimension") else EMBEDDING_DIMENSION
|
||||
)
|
||||
logger.debug(
|
||||
f"Using embedding config from metadata (takes precedence over extra_body): model='{embedding_model}', dimension={embedding_dimension}"
|
||||
)
|
||||
else:
|
||||
embedding_model = extra_body.get("embedding_model")
|
||||
embedding_dimension = extra_body.get("embedding_dimension", EMBEDDING_DIMENSION)
|
||||
logger.debug(
|
||||
f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
|
||||
)
|
||||
|
||||
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
||||
provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||
# Derive the canonical vector_db_id (allow override, else generate)
|
||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||
# Derive the canonical vector_store_id (allow override, else generate)
|
||||
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if embedding_model is None:
|
||||
raise ValueError("Embedding model is required")
|
||||
raise ValueError("embedding_model is required")
|
||||
|
||||
# Embedding dimension is required (defaulted to 768 if not provided)
|
||||
if embedding_dimension is None:
|
||||
raise ValueError("Embedding dimension is required")
|
||||
|
||||
# Register the VectorDB backing this vector store
|
||||
# Register the VectorStore backing this vector store
|
||||
if provider_id is None:
|
||||
raise ValueError("Provider ID is required but was not provided")
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier=vector_db_id,
|
||||
# call to the provider to create any index, etc.
|
||||
vector_store = VectorStore(
|
||||
identifier=vector_store_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
embedding_model=embedding_model,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=vector_db_id,
|
||||
vector_db_name=params.name,
|
||||
provider_resource_id=vector_store_id,
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
await self.register_vector_db(vector_db)
|
||||
await self.register_vector_store(vector_store)
|
||||
|
||||
# Create OpenAI vector store metadata
|
||||
status = "completed"
|
||||
|
|
@ -391,7 +415,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
total=0,
|
||||
)
|
||||
store_info: dict[str, Any] = {
|
||||
"id": vector_db_id,
|
||||
"id": vector_store_id,
|
||||
"object": "vector_store",
|
||||
"created_at": created_at,
|
||||
"name": params.name,
|
||||
|
|
@ -406,26 +430,25 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
}
|
||||
|
||||
# Add provider information to metadata if provided
|
||||
metadata = params.metadata or {}
|
||||
if provider_id:
|
||||
metadata["provider_id"] = provider_id
|
||||
if provider_vector_db_id:
|
||||
metadata["provider_vector_db_id"] = provider_vector_db_id
|
||||
if provider_vector_store_id:
|
||||
metadata["provider_vector_store_id"] = provider_vector_store_id
|
||||
store_info["metadata"] = metadata
|
||||
|
||||
# Save to persistent storage (provider-specific)
|
||||
await self._save_openai_vector_store(vector_db_id, store_info)
|
||||
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||
|
||||
# Store in memory cache
|
||||
self.openai_vector_stores[vector_db_id] = store_info
|
||||
self.openai_vector_stores[vector_store_id] = store_info
|
||||
|
||||
# Now that our vector store is created, attach any files that were provided
|
||||
file_ids = params.file_ids or []
|
||||
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
|
||||
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get the updated store info and return it
|
||||
store_info = self.openai_vector_stores[vector_db_id]
|
||||
store_info = self.openai_vector_stores[vector_store_id]
|
||||
return VectorStoreObject.model_validate(store_info)
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
|
|
@ -535,7 +558,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# Also delete the underlying vector DB
|
||||
try:
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
await self.unregister_vector_store(vector_store_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ from llama_stack.apis.common.content_types import (
|
|||
)
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -187,7 +187,7 @@ def make_overlapped_chunks(
|
|||
updated_timestamp=int(time.time()),
|
||||
chunk_window=chunk_window,
|
||||
chunk_tokenizer=default_tokenizer,
|
||||
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
|
||||
chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
|
||||
content_token_count=len(toks),
|
||||
metadata_token_count=len(metadata_tokens),
|
||||
)
|
||||
|
|
@ -255,8 +255,8 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
|
||||
@dataclass
|
||||
class VectorDBWithIndex:
|
||||
vector_db: VectorDB
|
||||
class VectorStoreWithIndex:
|
||||
vector_store: VectorStore
|
||||
index: EmbeddingIndex
|
||||
inference_api: Api.inference
|
||||
|
||||
|
|
@ -269,14 +269,14 @@ class VectorDBWithIndex:
|
|||
if c.embedding is None:
|
||||
chunks_to_embed.append(c)
|
||||
if c.chunk_metadata:
|
||||
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
|
||||
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
|
||||
c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
|
||||
c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
|
||||
else:
|
||||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
_validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||
model=self.vector_db.embedding_model,
|
||||
model=self.vector_store.embedding_model,
|
||||
input=[c.content for c in chunks_to_embed],
|
||||
)
|
||||
resp = await self.inference_api.openai_embeddings(params)
|
||||
|
|
@ -319,7 +319,7 @@ class VectorDBWithIndex:
|
|||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||
model=self.vector_db.embedding_model,
|
||||
model=self.vector_store.embedding_model,
|
||||
input=[query_string],
|
||||
)
|
||||
embeddings_response = await self.inference_api.openai_embeddings(params)
|
||||
|
|
|
|||
|
|
@ -18,13 +18,13 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
|
@ -45,39 +45,38 @@ class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
|
|||
class ResponsesStore:
|
||||
def __init__(
|
||||
self,
|
||||
config: ResponsesStoreConfig | SqlStoreConfig,
|
||||
reference: ResponsesStoreReference | SqlStoreReference,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
# Handle backward compatibility
|
||||
if not isinstance(config, ResponsesStoreConfig):
|
||||
# Legacy: SqlStoreConfig passed directly as config
|
||||
config = ResponsesStoreConfig(
|
||||
sql_store_config=config,
|
||||
)
|
||||
if isinstance(reference, ResponsesStoreReference):
|
||||
self.reference = reference
|
||||
else:
|
||||
self.reference = ResponsesStoreReference(**reference.model_dump())
|
||||
|
||||
self.config = config
|
||||
self.sql_store_config = config.sql_store_config
|
||||
if not self.sql_store_config:
|
||||
self.sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
self.sql_store = None
|
||||
self.enable_write_queue = True
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: (
|
||||
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
|
||||
) = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
self._max_write_queue_size: int = self.reference.max_write_queue_size
|
||||
self._num_writers: int = max(1, self.reference.num_writers)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
if backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
self.enable_write_queue = False
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
|
|
@ -88,12 +87,20 @@ class ResponsesStore:
|
|||
},
|
||||
)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"conversation_messages",
|
||||
{
|
||||
"conversation_id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"messages": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
if self.enable_write_queue:
|
||||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
|
|
@ -294,3 +301,54 @@ class ResponsesStore:
|
|||
items = items[:limit]
|
||||
|
||||
return ListOpenAIResponseInputItem(data=items)
|
||||
|
||||
async def store_conversation_messages(self, conversation_id: str, messages: list[OpenAIMessageParam]) -> None:
|
||||
"""Store messages for a conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:param messages: List of OpenAI message parameters to store.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
# Serialize messages to dict format for JSON storage
|
||||
messages_data = [msg.model_dump() for msg in messages]
|
||||
|
||||
# Upsert: try insert first, update if exists
|
||||
try:
|
||||
await self.sql_store.insert(
|
||||
table="conversation_messages",
|
||||
data={"conversation_id": conversation_id, "messages": messages_data},
|
||||
)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_messages",
|
||||
data={"messages": messages_data},
|
||||
where={"conversation_id": conversation_id},
|
||||
)
|
||||
|
||||
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
|
||||
|
||||
async def get_conversation_messages(self, conversation_id: str) -> list[OpenAIMessageParam] | None:
|
||||
"""Get stored messages for a conversation.
|
||||
|
||||
:param conversation_id: The conversation identifier.
|
||||
:returns: List of OpenAI message parameters, or None if no messages stored.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
record = await self.sql_store.fetch_one(
|
||||
table="conversation_messages",
|
||||
where={"conversation_id": conversation_id},
|
||||
)
|
||||
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
# Deserialize messages from JSON storage
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
adapter = TypeAdapter(list[OpenAIMessageParam])
|
||||
return adapter.validate_python(record["messages"])
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from llama_stack.core.access_control.conditions import ProtectedResource
|
|||
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.storage.datatypes import StorageBackendType
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||
from .sqlstore import SqlStoreType
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
|
@ -82,8 +82,8 @@ class AuthorizedSqlStore:
|
|||
if not hasattr(self.sql_store, "config"):
|
||||
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||
|
||||
self.database_type = self.sql_store.config.type
|
||||
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
||||
self.database_type = self.sql_store.config.type.value
|
||||
if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _validate_sql_optimized_policy(self) -> None:
|
||||
|
|
@ -220,9 +220,9 @@ class AuthorizedSqlStore:
|
|||
Returns:
|
||||
SQL expression to extract JSON value
|
||||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return f"{column}->'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
|
@ -237,9 +237,9 @@ class AuthorizedSqlStore:
|
|||
Returns:
|
||||
SQL expression to extract JSON value as text
|
||||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return f"{column}->>'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
|
@ -248,10 +248,10 @@ class AuthorizedSqlStore:
|
|||
"""Get the SQL conditions for public access."""
|
||||
# Public records are records that have no owner_principal or access_attributes
|
||||
conditions = ["owner_principal = ''"]
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
# Postgres stores JSON null as 'null'
|
||||
conditions.append("access_attributes::text = 'null'")
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
conditions.append("access_attributes = 'null'")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
|
|
|||
|
|
@ -26,10 +26,10 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,90 +4,28 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageBackendConfig,
|
||||
StorageBackendType,
|
||||
)
|
||||
|
||||
from .api import SqlStore
|
||||
|
||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
|
||||
class SqlStoreType(StrEnum):
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreConfig(BaseModel):
|
||||
@property
|
||||
@abstractmethod
|
||||
def engine_str(self) -> str: ...
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||
)
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
|
||||
return {
|
||||
"type": "sqlite",
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["aiosqlite"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["asyncpg"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs):
|
||||
return {
|
||||
"type": "postgres",
|
||||
"host": "${env.POSTGRES_HOST:=localhost}",
|
||||
"port": "${env.POSTGRES_PORT:=5432}",
|
||||
"db": "${env.POSTGRES_DB:=llamastack}",
|
||||
"user": "${env.POSTGRES_USER:=llamastack}",
|
||||
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
|
||||
}
|
||||
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
|
||||
Field(discriminator="type", default=SqlStoreType.sqlite.value),
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -95,9 +33,9 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
|||
"""Get pip packages for SQL store config, handling both dict and object cases."""
|
||||
if isinstance(store_config, dict):
|
||||
store_type = store_config.get("type")
|
||||
if store_type == "sqlite":
|
||||
if store_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return SqliteSqlStoreConfig.pip_packages()
|
||||
elif store_type == "postgres":
|
||||
elif store_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return PostgresSqlStoreConfig.pip_packages()
|
||||
else:
|
||||
raise ValueError(f"Unknown SQL store type: {store_type}")
|
||||
|
|
@ -105,12 +43,28 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
|||
return store_config.pip_packages()
|
||||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
|
||||
def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
|
||||
backend_name = reference.backend
|
||||
|
||||
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
impl = SqlAlchemySqlStoreImpl(config)
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
return SqlAlchemySqlStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
|
||||
return impl
|
||||
|
||||
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
|
||||
"""Register the set of available SQL store backends for reference resolution."""
|
||||
global _SQLSTORE_BACKENDS
|
||||
|
||||
_SQLSTORE_BACKENDS.clear()
|
||||
for name, cfg in backends.items():
|
||||
_SQLSTORE_BACKENDS[name] = cfg
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
|
||||
|
||||
|
||||
class TelemetryDatasetMixin:
|
||||
"""Mixin class that provides dataset-related functionality for telemetry providers."""
|
||||
|
||||
datasetio_api: DatasetIO | None
|
||||
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_save: list[str],
|
||||
dataset_id: str,
|
||||
max_depth: int | None = None,
|
||||
) -> None:
|
||||
if self.datasetio_api is None:
|
||||
raise RuntimeError("DatasetIO API not available")
|
||||
|
||||
spans = await self.query_spans(
|
||||
attribute_filters=attribute_filters,
|
||||
attributes_to_return=attributes_to_save,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
rows = [
|
||||
{
|
||||
"trace_id": span.trace_id,
|
||||
"span_id": span.span_id,
|
||||
"parent_span_id": span.parent_span_id,
|
||||
"name": span.name,
|
||||
"start_time": span.start_time,
|
||||
"end_time": span.end_time,
|
||||
**{attr: span.attributes.get(attr) for attr in attributes_to_save},
|
||||
}
|
||||
for span in spans
|
||||
]
|
||||
|
||||
await self.datasetio_api.append_rows(dataset_id=dataset_id, rows=rows)
|
||||
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_return: list[str],
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpansResponse:
|
||||
traces = await self.query_traces(attribute_filters=attribute_filters)
|
||||
spans = []
|
||||
|
||||
for trace in traces.data:
|
||||
spans_by_id_resp = await self.get_span_tree(
|
||||
span_id=trace.root_span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
for span in spans_by_id_resp.data.values():
|
||||
if span.attributes and all(
|
||||
attr in span.attributes and span.attributes[attr] is not None for attr in attributes_to_return
|
||||
):
|
||||
spans.append(
|
||||
Span(
|
||||
trace_id=trace.root_span_id,
|
||||
span_id=span.span_id,
|
||||
parent_span_id=span.parent_span_id,
|
||||
name=span.name,
|
||||
start_time=span.start_time,
|
||||
end_time=span.end_time,
|
||||
attributes=span.attributes,
|
||||
)
|
||||
)
|
||||
|
||||
return QuerySpansResponse(data=spans)
|
||||
|
|
@ -1,383 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
MetricDataPoint,
|
||||
MetricLabel,
|
||||
MetricLabelMatcher,
|
||||
MetricQueryType,
|
||||
MetricSeries,
|
||||
QueryCondition,
|
||||
QueryMetricsResponse,
|
||||
Span,
|
||||
SpanWithStatus,
|
||||
Trace,
|
||||
)
|
||||
|
||||
|
||||
class TraceStore(Protocol):
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> list[Trace]: ...
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> dict[str, SpanWithStatus]: ...
|
||||
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None = None,
|
||||
granularity: str | None = "1d",
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse: ...
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
def __init__(self, conn_string: str):
|
||||
self.conn_string = conn_string
|
||||
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None = None,
|
||||
granularity: str | None = None,
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
if end_time is None:
|
||||
end_time = datetime.now(UTC)
|
||||
|
||||
# Build base query
|
||||
if query_type == MetricQueryType.INSTANT:
|
||||
query = """
|
||||
SELECT
|
||||
se.name,
|
||||
SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value,
|
||||
json_extract(se.attributes, '$.unit') as unit,
|
||||
se.attributes
|
||||
FROM span_events se
|
||||
WHERE se.name = ?
|
||||
AND se.timestamp BETWEEN ? AND ?
|
||||
"""
|
||||
else:
|
||||
if granularity:
|
||||
time_format = self._get_time_format_for_granularity(granularity)
|
||||
query = f"""
|
||||
SELECT
|
||||
se.name,
|
||||
SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value,
|
||||
json_extract(se.attributes, '$.unit') as unit,
|
||||
se.attributes,
|
||||
strftime('{time_format}', se.timestamp) as bucket_start
|
||||
FROM span_events se
|
||||
WHERE se.name = ?
|
||||
AND se.timestamp BETWEEN ? AND ?
|
||||
"""
|
||||
else:
|
||||
query = """
|
||||
SELECT
|
||||
se.name,
|
||||
json_extract(se.attributes, '$.value') as value,
|
||||
json_extract(se.attributes, '$.unit') as unit,
|
||||
se.attributes,
|
||||
se.timestamp
|
||||
FROM span_events se
|
||||
WHERE se.name = ?
|
||||
AND se.timestamp BETWEEN ? AND ?
|
||||
"""
|
||||
|
||||
params = [f"metric.{metric_name}", start_time.isoformat(), end_time.isoformat()]
|
||||
|
||||
# Labels that will be attached to the MetricSeries (preserve matcher labels)
|
||||
all_labels: list[MetricLabel] = []
|
||||
matcher_label_names = set()
|
||||
if label_matchers:
|
||||
for matcher in label_matchers:
|
||||
json_path = f"$.{matcher.name}"
|
||||
if matcher.operator == "=":
|
||||
query += f" AND json_extract(se.attributes, '{json_path}') = ?"
|
||||
params.append(matcher.value)
|
||||
elif matcher.operator == "!=":
|
||||
query += f" AND json_extract(se.attributes, '{json_path}') != ?"
|
||||
params.append(matcher.value)
|
||||
elif matcher.operator == "=~":
|
||||
query += f" AND json_extract(se.attributes, '{json_path}') LIKE ?"
|
||||
params.append(f"%{matcher.value}%")
|
||||
elif matcher.operator == "!~":
|
||||
query += f" AND json_extract(se.attributes, '{json_path}') NOT LIKE ?"
|
||||
params.append(f"%{matcher.value}%")
|
||||
# Preserve filter context in output
|
||||
all_labels.append(MetricLabel(name=matcher.name, value=str(matcher.value)))
|
||||
matcher_label_names.add(matcher.name)
|
||||
|
||||
# GROUP BY / ORDER BY logic
|
||||
if query_type == MetricQueryType.RANGE and granularity:
|
||||
group_time_format = self._get_time_format_for_granularity(granularity)
|
||||
query += f" GROUP BY strftime('{group_time_format}', se.timestamp), json_extract(se.attributes, '$.unit')"
|
||||
query += " ORDER BY bucket_start"
|
||||
elif query_type == MetricQueryType.INSTANT:
|
||||
query += " GROUP BY json_extract(se.attributes, '$.unit')"
|
||||
else:
|
||||
query += " ORDER BY se.timestamp"
|
||||
|
||||
# Execute query
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
return QueryMetricsResponse(data=[])
|
||||
|
||||
data_points = []
|
||||
# We want to add attribute labels, but only those not already present as matcher labels.
|
||||
attr_label_names = set()
|
||||
for row in rows:
|
||||
# Parse JSON attributes safely, if there are no attributes (weird), just don't add the labels to the result.
|
||||
try:
|
||||
attributes = json.loads(row["attributes"] or "{}")
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
attributes = {}
|
||||
|
||||
value = row["value"]
|
||||
unit = row["unit"] or ""
|
||||
|
||||
# Add labels from attributes without duplicating matcher labels, if we don't do this, there will be a lot of duplicate label in the result.
|
||||
for k, v in attributes.items():
|
||||
if k not in ["value", "unit"] and k not in matcher_label_names and k not in attr_label_names:
|
||||
all_labels.append(MetricLabel(name=k, value=str(v)))
|
||||
attr_label_names.add(k)
|
||||
|
||||
# Determine timestamp
|
||||
if query_type == MetricQueryType.RANGE and granularity:
|
||||
try:
|
||||
bucket_start_raw = row["bucket_start"]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"DB did not have a bucket_start time in row when using granularity, this indicates improper formatting"
|
||||
) from e
|
||||
# this value could also be there, but be NULL, I think.
|
||||
if bucket_start_raw is None:
|
||||
raise ValueError("bucket_start is None check time format and data")
|
||||
bucket_start = datetime.fromisoformat(bucket_start_raw)
|
||||
timestamp = int(bucket_start.timestamp())
|
||||
elif query_type == MetricQueryType.INSTANT:
|
||||
timestamp = int(datetime.now(UTC).timestamp())
|
||||
else:
|
||||
try:
|
||||
timestamp_raw = row["timestamp"]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"DB did not have a timestamp in row, this indicates improper formatting"
|
||||
) from e
|
||||
# this value could also be there, but be NULL, I think.
|
||||
if timestamp_raw is None:
|
||||
raise ValueError("timestamp is None check time format and data")
|
||||
timestamp_iso = datetime.fromisoformat(timestamp_raw)
|
||||
timestamp = int(timestamp_iso.timestamp())
|
||||
|
||||
data_points.append(
|
||||
MetricDataPoint(
|
||||
timestamp=timestamp,
|
||||
value=value,
|
||||
unit=unit,
|
||||
)
|
||||
)
|
||||
|
||||
metric_series = [MetricSeries(metric=metric_name, labels=all_labels, values=data_points)]
|
||||
return QueryMetricsResponse(data=metric_series)
|
||||
|
||||
def _get_time_format_for_granularity(self, granularity: str | None) -> str:
|
||||
"""Get the SQLite strftime format string for a given granularity.
|
||||
Args:
|
||||
granularity: Granularity string (e.g., "1m", "5m", "1h", "1d")
|
||||
Returns:
|
||||
SQLite strftime format string for the granularity
|
||||
"""
|
||||
if granularity is None:
|
||||
raise ValueError("granularity cannot be None for this method - use separate logic for no aggregation")
|
||||
|
||||
if granularity.endswith("d"):
|
||||
return "%Y-%m-%d 00:00:00"
|
||||
elif granularity.endswith("h"):
|
||||
return "%Y-%m-%d %H:00:00"
|
||||
elif granularity.endswith("m"):
|
||||
return "%Y-%m-%d %H:%M:00"
|
||||
else:
|
||||
return "%Y-%m-%d %H:%M:00" # Default to most granular which will give us the most timestamps.
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> list[Trace]:
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
||||
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
|
||||
|
||||
conditions = [
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op.value]} ?"
|
||||
for condition in attribute_filters
|
||||
]
|
||||
params = [condition.value for condition in attribute_filters]
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
return where_clause, params
|
||||
|
||||
def build_order_clause() -> str:
|
||||
if not order_by:
|
||||
return ""
|
||||
|
||||
order_clauses = []
|
||||
for field in order_by:
|
||||
desc = field.startswith("-")
|
||||
clean_field = field[1:] if desc else field
|
||||
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||
return " ORDER BY " + ", ".join(order_clauses)
|
||||
|
||||
# Build the main query
|
||||
base_query = """
|
||||
WITH matching_traces AS (
|
||||
SELECT DISTINCT t.trace_id
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
{where_clause}
|
||||
),
|
||||
filtered_traces AS (
|
||||
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
FROM matching_traces mt
|
||||
JOIN traces t ON mt.trace_id = t.trace_id
|
||||
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||
{order_clause}
|
||||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
WHERE root_span_id IS NOT NULL
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
where_clause, params = build_where_clause()
|
||||
query = base_query.format(
|
||||
where_clause=where_clause,
|
||||
order_clause=build_order_clause(),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Execute query and return results
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
Trace(
|
||||
trace_id=row["trace_id"],
|
||||
root_span_id=row["root_span_id"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> dict[str, SpanWithStatus]:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
json_object = ", ".join(f"'{key}', json_extract(s.attributes, '$.{key}')" for key in attributes_to_return)
|
||||
attributes_select = f"json_object({json_object})"
|
||||
|
||||
# SQLite CTE query with filtered attributes
|
||||
query = f"""
|
||||
WITH RECURSIVE span_tree AS (
|
||||
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
WHERE s.span_id = ?
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
JOIN span_tree st ON s.parent_span_id = st.span_id
|
||||
WHERE (? IS NULL OR st.depth < ?)
|
||||
)
|
||||
SELECT *
|
||||
FROM span_tree
|
||||
ORDER BY depth, start_time
|
||||
"""
|
||||
|
||||
spans_by_id = {}
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
|
||||
for row in rows:
|
||||
span = SpanWithStatus(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
name=row["name"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
attributes=json.loads(row["filtered_attributes"]),
|
||||
status=row["status"].lower(),
|
||||
)
|
||||
|
||||
spans_by_id[span.span_id] = span
|
||||
|
||||
return spans_by_id
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
query = """
|
||||
SELECT *
|
||||
FROM traces t
|
||||
WHERE t.trace_id = ?
|
||||
"""
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (trace_id,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
raise ValueError(f"Trace {trace_id} not found")
|
||||
return Trace(**row)
|
||||
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
query = "SELECT * FROM spans WHERE trace_id = ? AND span_id = ?"
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (trace_id, span_id)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
return Span(**row)
|
||||
|
|
@ -70,7 +70,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
|
|||
"__class__": class_name,
|
||||
"__method__": method_name,
|
||||
"__type__": span_type,
|
||||
"__args__": str(combined_args),
|
||||
"__args__": json.dumps(combined_args),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
|
@ -82,8 +82,8 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
|
|||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
count = 0
|
||||
try:
|
||||
count = 0
|
||||
async for item in method(self, *args, **kwargs):
|
||||
yield item
|
||||
count += 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue