mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 10:10:36 +00:00)

commit 607e3cc05c: Merge branch 'main' into add-mcp-authentication-param
44 changed files with 1899 additions and 464 deletions
@@ -10,7 +10,7 @@
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from fastapi import Body
from fastapi import Body, Query
from pydantic import BaseModel, Field

from llama_stack.apis.common.tracing import telemetry_traceable

@@ -224,10 +224,16 @@ class VectorStoreContent(BaseModel):

    :param type: Content type, currently only "text" is supported
    :param text: The actual text content
    :param embedding: Optional embedding vector for this content chunk
    :param chunk_metadata: Optional chunk metadata
    :param metadata: Optional user-defined metadata
    """

    type: Literal["text"]
    text: str
    embedding: list[float] | None = None
    chunk_metadata: ChunkMetadata | None = None
    metadata: dict[str, Any] | None = None


@json_schema_type

@@ -280,6 +286,22 @@ class VectorStoreDeleteResponse(BaseModel):
    deleted: bool = True


@json_schema_type
class VectorStoreFileContentResponse(BaseModel):
    """Represents the parsed content of a vector store file.

    :param object: The object type, which is always `vector_store.file_content.page`
    :param data: Parsed content of the file
    :param has_more: Indicates if there are more content pages to fetch
    :param next_page: The token for the next page, if any
    """

    object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
    data: list[VectorStoreContent]
    has_more: bool = False
    next_page: str | None = None


@json_schema_type
class VectorStoreChunkingStrategyAuto(BaseModel):
    """Automatic chunking strategy for vector store files.

@@ -395,22 +417,6 @@ class VectorStoreListFilesResponse(BaseModel):
    has_more: bool = False


@json_schema_type
class VectorStoreFileContentResponse(BaseModel):
    """Represents the parsed content of a vector store file.

    :param object: The object type, which is always `vector_store.file_content.page`
    :param data: Parsed content of the file
    :param has_more: Indicates if there are more content pages to fetch
    :param next_page: The token for the next page, if any
    """

    object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
    data: list[VectorStoreContent]
    has_more: bool
    next_page: str | None = None


@json_schema_type
class VectorStoreFileDeleteResponse(BaseModel):
    """Response from deleting a vector store file.

@@ -732,12 +738,16 @@ class VectorIO(Protocol):
        self,
        vector_store_id: str,
        file_id: str,
        include_embeddings: Annotated[bool | None, Query(default=False)] = False,
        include_metadata: Annotated[bool | None, Query(default=False)] = False,
    ) -> VectorStoreFileContentResponse:
        """Retrieves the contents of a vector store file.

        :param vector_store_id: The ID of the vector store containing the file to retrieve.
        :param file_id: The ID of the file to retrieve.
        :returns: A VectorStoreFileContentResponse representing the file contents.
        :param include_embeddings: Whether to include embedding vectors in the response.
        :param include_metadata: Whether to include chunk metadata in the response.
        :returns: File contents, optionally with embeddings and metadata based on query parameters.
        """
        ...
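The two new flags on the file-contents method are declared as FastAPI query parameters, so callers opt into embeddings and metadata per request. A hedged sketch of a raw HTTP call; the host, port, route prefix, and IDs are assumptions for illustration, not taken from this diff:

```python
# Illustrative only: assumes a local Llama Stack server on port 8321 with the
# OpenAI-compatible vector-store routes; vs_123 / file_456 are placeholders.
import httpx

resp = httpx.get(
    "http://localhost:8321/v1/vector_stores/vs_123/files/file_456/content",
    params={"include_embeddings": True, "include_metadata": True},
)
resp.raise_for_status()
page = resp.json()
for item in page["data"]:
    print(item["text"][:80], "has embedding:", item.get("embedding") is not None)
```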
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib.resources
import sys

from pydantic import BaseModel

@@ -12,9 +11,6 @@ from termcolor import cprint

from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

@@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig):
    for special_dep in special_deps:
        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
    print()


def build_image(
    build_config: BuildConfig,
    image_name: str,
    distro_or_config: str,
    run_config: str | None = None,
):
    container_base = build_config.distribution_spec.container_image or "python:3.12-slim"

    normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES
    if build_config.external_apis_dir:
        external_apis = load_external_apis(build_config)
        if external_apis:
            for _, api_spec in external_apis.items():
                normal_deps.extend(api_spec.pip_packages)

    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
        args = [
            script,
            "--distro-or-config",
            distro_or_config,
            "--image-name",
            image_name,
            "--container-base",
            container_base,
            "--normal-deps",
            " ".join(normal_deps),
        ]
        # When building from a config file (not a template), include the run config path in the
        # build arguments
        if run_config is not None:
            args.extend(["--run-config", run_config])
    else:
        script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
        args = [
            script,
            "--env-name",
            str(image_name),
            "--normal-deps",
            " ".join(normal_deps),
        ]

    # Always pass both arguments, even if empty, to maintain consistent positional arguments
    if special_deps:
        args.extend(["--optional-deps", "#".join(special_deps)])
    if external_provider_deps:
        args.extend(
            ["--external-provider-deps", "#".join(external_provider_deps)]
        )  # the script will install external provider module, get its deps, and install those too.

    return_code = run_command(args)

    if return_code != 0:
        log.error(
            f"Failed to build target {image_name} with return code {return_code}",
        )

    return return_code
@@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
                "item_data": item_dict,
            }

            # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
            try:
                await self.sql_store.insert(table="conversation_items", data=item_record)
            except Exception:
                # If insert fails due to ID conflict, update existing record
                await self.sql_store.update(
                    table="conversation_items",
                    data={"created_at": created_at, "item_data": item_dict},
                    where={"id": item_id},
                )
            await self.sql_store.upsert(
                table="conversation_items",
                data=item_record,
                conflict_columns=["id"],
            )

            created_items.append(item_dict)
@@ -389,6 +389,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
        body |= path_params

        # Pass through params that aren't already handled as path params
        if options.params:
            extra_query_params = {k: v for k, v in options.params.items() if k not in path_params}
            if extra_query_params:
                body["extra_query"] = extra_query_params

        body, field_names = self._handle_file_uploads(options, body)

        body = self._convert_body(matched_func, body, exclude_params=set(field_names))
@@ -247,6 +247,13 @@ class VectorIORouter(VectorIO):
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")

        # Check if provider_id is being changed (not supported)
        if metadata and "provider_id" in metadata:
            current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id)
            if current_store and current_store.provider_id != metadata["provider_id"]:
                raise ValueError("provider_id cannot be changed after vector store creation")

        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,

@@ -338,12 +345,19 @@
        self,
        vector_store_id: str,
        file_id: str,
        include_embeddings: bool | None = False,
        include_metadata: bool | None = False,
    ) -> VectorStoreFileContentResponse:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
        logger.debug(
            f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, "
            f"include_embeddings={include_embeddings}, include_metadata={include_metadata}"
        )

        return await self.routing_table.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
            include_embeddings=include_embeddings,
            include_metadata=include_metadata,
        )

    async def openai_update_vector_store_file(
@@ -195,12 +195,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
        self,
        vector_store_id: str,
        file_id: str,
        include_embeddings: bool | None = False,
        include_metadata: bool | None = False,
    ) -> VectorStoreFileContentResponse:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)

        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
            include_embeddings=include_embeddings,
            include_metadata=include_metadata,
        )

    async def openai_update_vector_store_file(
@@ -13,6 +13,5 @@ from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
    template = get_starter_distribution_template(name="ci-tests")
    template.description = "CI tests for Llama Stack"
    template.run_configs.pop("run-with-postgres-store.yaml", None)

    return template
@@ -0,0 +1,293 @@
version: 2
image_name: ci-tests
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
    provider_type: remote::cerebras
    config:
      base_url: https://api.cerebras.ai
      api_key: ${env.CEREBRAS_API_KEY:=}
  - provider_id: ${env.OLLAMA_URL:+ollama}
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:=http://localhost:11434}
  - provider_id: ${env.VLLM_URL:+vllm}
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_URL:=}
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
  - provider_id: ${env.TGI_URL:+tgi}
    provider_type: remote::tgi
    config:
      url: ${env.TGI_URL:=}
  - provider_id: fireworks
    provider_type: remote::fireworks
    config:
      url: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY:=}
  - provider_id: together
    provider_type: remote::together
    config:
      url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY:=}
  - provider_id: bedrock
    provider_type: remote::bedrock
    config:
      api_key: ${env.AWS_BEDROCK_API_KEY:=}
      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
  - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
    provider_type: remote::nvidia
    config:
      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
      api_key: ${env.NVIDIA_API_KEY:=}
      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
  - provider_id: openai
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
      api_key: ${env.ANTHROPIC_API_KEY:=}
  - provider_id: gemini
    provider_type: remote::gemini
    config:
      api_key: ${env.GEMINI_API_KEY:=}
  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
    provider_type: remote::vertexai
    config:
      project: ${env.VERTEX_AI_PROJECT:=}
      location: ${env.VERTEX_AI_LOCATION:=us-central1}
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY:=}
  - provider_id: sambanova
    provider_type: remote::sambanova
    config:
      url: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY:=}
  - provider_id: ${env.AZURE_API_KEY:+azure}
    provider_type: remote::azure
    config:
      api_key: ${env.AZURE_API_KEY:=}
      api_base: ${env.AZURE_API_BASE:=}
      api_version: ${env.AZURE_API_VERSION:=}
      api_type: ${env.AZURE_API_TYPE:=}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      persistence:
        namespace: vector_io::faiss
        backend: kv_default
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
      persistence:
        namespace: vector_io::sqlite_vec
        backend: kv_default
  - provider_id: ${env.MILVUS_URL:+milvus}
    provider_type: inline::milvus
    config:
      db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
      persistence:
        namespace: vector_io::milvus
        backend: kv_default
  - provider_id: ${env.CHROMADB_URL:+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:=}
      persistence:
        namespace: vector_io::chroma_remote
        backend: kv_default
  - provider_id: ${env.PGVECTOR_DB:+pgvector}
    provider_type: remote::pgvector
    config:
      host: ${env.PGVECTOR_HOST:=localhost}
      port: ${env.PGVECTOR_PORT:=5432}
      db: ${env.PGVECTOR_DB:=}
      user: ${env.PGVECTOR_USER:=}
      password: ${env.PGVECTOR_PASSWORD:=}
      persistence:
        namespace: vector_io::pgvector
        backend: kv_default
  - provider_id: ${env.QDRANT_URL:+qdrant}
    provider_type: remote::qdrant
    config:
      api_key: ${env.QDRANT_API_KEY:=}
      persistence:
        namespace: vector_io::qdrant_remote
        backend: kv_default
  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
    provider_type: remote::weaviate
    config:
      weaviate_api_key: null
      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
      persistence:
        namespace: vector_io::weaviate
        backend: kv_default
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
      metadata_store:
        table_name: files_metadata
        backend: sql_default
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  - provider_id: code-scanner
    provider_type: inline::code-scanner
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence:
        agent_state:
          namespace: agents
          backend: kv_default
        responses:
          table_name: responses
          backend: sql_default
          max_write_queue_size: 10000
          num_writers: 4
  post_training:
  - provider_id: torchtune-cpu
    provider_type: inline::torchtune-cpu
    config:
      checkpoint_format: meta
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        namespace: eval
        backend: kv_default
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        namespace: datasetio::huggingface
        backend: kv_default
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        namespace: datasetio::localfs
        backend: kv_default
  scoring:
  - provider_id: basic
    provider_type: inline::basic
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:=}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
  batches:
  - provider_id: reference
    provider_type: inline::reference
    config:
      kvstore:
        namespace: batches
        backend: kv_default
storage:
  backends:
    kv_default:
      type: kv_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}
      db: ${env.POSTGRES_DB:=llamastack}
      user: ${env.POSTGRES_USER:=llamastack}
      password: ${env.POSTGRES_PASSWORD:=llamastack}
      table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
    sql_default:
      type: sql_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}
      db: ${env.POSTGRES_DB:=llamastack}
      user: ${env.POSTGRES_USER:=llamastack}
      password: ${env.POSTGRES_PASSWORD:=llamastack}
  stores:
    metadata:
      namespace: registry
      backend: kv_default
    inference:
      table_name: inference_store
      backend: sql_default
      max_write_queue_size: 10000
      num_writers: 4
    conversations:
      table_name: openai_conversations
      backend: sql_default
    prompts:
      namespace: prompts
      backend: kv_default
registered_resources:
  models: []
  shields:
  - shield_id: llama-guard
    provider_id: ${env.SAFETY_MODEL:+llama-guard}
    provider_shield_id: ${env.SAFETY_MODEL:=}
  - shield_id: code-scanner
    provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
    provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
  vector_dbs: []
  datasets: []
  scoring_fns: []
  benchmarks: []
  tool_groups:
  - toolgroup_id: builtin::websearch
    provider_id: tavily-search
  - toolgroup_id: builtin::rag
    provider_id: rag-runtime
server:
  port: 8321
telemetry:
  enabled: true
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
safety:
  default_shield_id: llama-guard
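Throughout these configs, `${env.VAR:=default}` expands to the variable's value or the given default, while `${env.VAR:+value}` expands to `value` only when the variable is set, so a provider whose key is absent ends up with an empty `provider_id` and is effectively disabled. A minimal sketch of these two bash-style operators, written to illustrate the semantics rather than reproduce Llama Stack's actual resolver:

```python
# Illustrative sketch of ${env.VAR:=default} / ${env.VAR:+value} semantics;
# not the actual Llama Stack implementation.
import os
import re

_PATTERN = re.compile(r"\$\{env\.(\w+):([=+])([^}]*)\}")

def resolve(text: str) -> str:
    def sub(m: re.Match) -> str:
        var, op, arg = m.group(1), m.group(2), m.group(3)
        value = os.environ.get(var)
        if op == "=":
            return value if value else arg  # env value, falling back to default
        return arg if value else ""          # ':+' substitutes only when set
    return _PATTERN.sub(sub, text)

os.environ.pop("CHROMADB_URL", None)
print(resolve("provider_id: ${env.CHROMADB_URL:+chromadb}"))  # -> "provider_id: "
os.environ["CHROMADB_URL"] = "http://localhost:8000"
print(resolve("url: ${env.CHROMADB_URL:=}"))  # -> "url: http://localhost:8000"
```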
@@ -165,20 +165,15 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sql_postgres
        host: ${env.POSTGRES_HOST:=localhost}
        port: ${env.POSTGRES_PORT:=5432}
        db: ${env.POSTGRES_DB:=llamastack}
        user: ${env.POSTGRES_USER:=llamastack}
        password: ${env.POSTGRES_PASSWORD:=llamastack}
      responses_store:
        type: sql_postgres
        host: ${env.POSTGRES_HOST:=localhost}
        port: ${env.POSTGRES_PORT:=5432}
        db: ${env.POSTGRES_DB:=llamastack}
        user: ${env.POSTGRES_USER:=llamastack}
        password: ${env.POSTGRES_PASSWORD:=llamastack}
      persistence:
        agent_state:
          namespace: agents
          backend: kv_default
        responses:
          table_name: responses
          backend: sql_default
          max_write_queue_size: 10000
          num_writers: 4
  post_training:
  - provider_id: huggingface-gpu
    provider_type: inline::huggingface-gpu

@@ -237,10 +232,10 @@ providers:
    config:
      kvstore:
        namespace: batches
        backend: kv_postgres
        backend: kv_default
storage:
  backends:
    kv_postgres:
    kv_default:
      type: kv_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}

@@ -248,7 +243,7 @@ storage:
      user: ${env.POSTGRES_USER:=llamastack}
      password: ${env.POSTGRES_PASSWORD:=llamastack}
      table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
    sql_postgres:
    sql_default:
      type: sql_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}

@@ -258,27 +253,44 @@ storage:
  stores:
    metadata:
      namespace: registry
      backend: kv_postgres
      backend: kv_default
    inference:
      table_name: inference_store
      backend: sql_postgres
      backend: sql_default
      max_write_queue_size: 10000
      num_writers: 4
    conversations:
      table_name: openai_conversations
      backend: sql_postgres
      backend: sql_default
    prompts:
      namespace: prompts
      backend: kv_postgres
      backend: kv_default
registered_resources:
  models: []
  shields: []
  shields:
  - shield_id: llama-guard
    provider_id: ${env.SAFETY_MODEL:+llama-guard}
    provider_shield_id: ${env.SAFETY_MODEL:=}
  - shield_id: code-scanner
    provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
    provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
  vector_dbs: []
  datasets: []
  scoring_fns: []
  benchmarks: []
  tool_groups: []
  tool_groups:
  - toolgroup_id: builtin::websearch
    provider_id: tavily-search
  - toolgroup_id: builtin::rag
    provider_id: rag-runtime
server:
  port: 8321
telemetry:
  enabled: true
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
safety:
  default_shield_id: llama-guard
@@ -165,20 +165,15 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sql_postgres
        host: ${env.POSTGRES_HOST:=localhost}
        port: ${env.POSTGRES_PORT:=5432}
        db: ${env.POSTGRES_DB:=llamastack}
        user: ${env.POSTGRES_USER:=llamastack}
        password: ${env.POSTGRES_PASSWORD:=llamastack}
      responses_store:
        type: sql_postgres
        host: ${env.POSTGRES_HOST:=localhost}
        port: ${env.POSTGRES_PORT:=5432}
        db: ${env.POSTGRES_DB:=llamastack}
        user: ${env.POSTGRES_USER:=llamastack}
        password: ${env.POSTGRES_PASSWORD:=llamastack}
      persistence:
        agent_state:
          namespace: agents
          backend: kv_default
        responses:
          table_name: responses
          backend: sql_default
          max_write_queue_size: 10000
          num_writers: 4
  post_training:
  - provider_id: torchtune-cpu
    provider_type: inline::torchtune-cpu

@@ -234,10 +229,10 @@ providers:
    config:
      kvstore:
        namespace: batches
        backend: kv_postgres
        backend: kv_default
storage:
  backends:
    kv_postgres:
    kv_default:
      type: kv_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}

@@ -245,7 +240,7 @@ storage:
      user: ${env.POSTGRES_USER:=llamastack}
      password: ${env.POSTGRES_PASSWORD:=llamastack}
      table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
    sql_postgres:
    sql_default:
      type: sql_postgres
      host: ${env.POSTGRES_HOST:=localhost}
      port: ${env.POSTGRES_PORT:=5432}

@@ -255,27 +250,44 @@ storage:
  stores:
    metadata:
      namespace: registry
      backend: kv_postgres
      backend: kv_default
    inference:
      table_name: inference_store
      backend: sql_postgres
      backend: sql_default
      max_write_queue_size: 10000
      num_writers: 4
    conversations:
      table_name: openai_conversations
      backend: sql_postgres
      backend: sql_default
    prompts:
      namespace: prompts
      backend: kv_postgres
      backend: kv_default
registered_resources:
  models: []
  shields: []
  shields:
  - shield_id: llama-guard
    provider_id: ${env.SAFETY_MODEL:+llama-guard}
    provider_shield_id: ${env.SAFETY_MODEL:=}
  - shield_id: code-scanner
    provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
    provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
  vector_dbs: []
  datasets: []
  scoring_fns: []
  benchmarks: []
  tool_groups: []
  tool_groups:
  - toolgroup_id: builtin::websearch
    provider_id: tavily-search
  - toolgroup_id: builtin::rag
    provider_id: rag-runtime
server:
  port: 8321
telemetry:
  enabled: true
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
safety:
  default_shield_id: llama-guard
@@ -17,11 +17,6 @@ from llama_stack.core.datatypes import (
    ToolGroupInput,
    VectorStoresConfig,
)
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    SqlStoreReference,
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec

@@ -154,10 +149,11 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
            BuildProvider(provider_type="inline::reference"),
        ],
    }
    files_config = LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}")
    files_provider = Provider(
        provider_id="meta-reference-files",
        provider_type="inline::localfs",
        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
        config=files_config,
    )
    embedding_provider = Provider(
        provider_id="sentence-transformers",

@@ -187,7 +183,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
            provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
        ),
    ]
    postgres_config = PostgresSqlStoreConfig.sample_run_config()
    postgres_sql_config = PostgresSqlStoreConfig.sample_run_config()
    postgres_kv_config = PostgresKVStoreConfig.sample_run_config()
    default_overrides = {
        "inference": remote_inference_providers + [embedding_provider],
        "vector_io": [

@@ -244,6 +241,33 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
        "files": [files_provider],
    }

    base_run_settings = RunConfigSettings(
        provider_overrides=default_overrides,
        default_models=[],
        default_tool_groups=default_tool_groups,
        default_shields=default_shields,
        vector_stores_config=VectorStoresConfig(
            default_provider_id="faiss",
            default_embedding_model=QualifiedModel(
                provider_id="sentence-transformers",
                model_id="nomic-ai/nomic-embed-text-v1.5",
            ),
        ),
        safety_config=SafetyConfig(
            default_shield_id="llama-guard",
        ),
    )

    postgres_run_settings = base_run_settings.model_copy(
        update={
            "storage_backends": {
                "kv_default": postgres_kv_config,
                "sql_default": postgres_sql_config,
            }
        },
        deep=True,
    )

    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",

@@ -253,71 +277,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
        providers=providers,
        additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides=default_overrides,
                default_models=[],
                default_tool_groups=default_tool_groups,
                default_shields=default_shields,
                vector_stores_config=VectorStoresConfig(
                    default_provider_id="faiss",
                    default_embedding_model=QualifiedModel(
                        provider_id="sentence-transformers",
                        model_id="nomic-ai/nomic-embed-text-v1.5",
                    ),
                ),
                safety_config=SafetyConfig(
                    default_shield_id="llama-guard",
                ),
            ),
            "run-with-postgres-store.yaml": RunConfigSettings(
                provider_overrides={
                    **default_overrides,
                    "agents": [
                        Provider(
                            provider_id="meta-reference",
                            provider_type="inline::meta-reference",
                            config=dict(
                                persistence_store=postgres_config,
                                responses_store=postgres_config,
                            ),
                        )
                    ],
                    "batches": [
                        Provider(
                            provider_id="reference",
                            provider_type="inline::reference",
                            config=dict(
                                kvstore=KVStoreReference(
                                    backend="kv_postgres",
                                    namespace="batches",
                                ).model_dump(exclude_none=True),
                            ),
                        )
                    ],
                },
                storage_backends={
                    "kv_postgres": PostgresKVStoreConfig.sample_run_config(),
                    "sql_postgres": postgres_config,
                },
                storage_stores={
                    "metadata": KVStoreReference(
                        backend="kv_postgres",
                        namespace="registry",
                    ).model_dump(exclude_none=True),
                    "inference": InferenceStoreReference(
                        backend="sql_postgres",
                        table_name="inference_store",
                    ).model_dump(exclude_none=True),
                    "conversations": SqlStoreReference(
                        backend="sql_postgres",
                        table_name="openai_conversations",
                    ).model_dump(exclude_none=True),
                    "prompts": KVStoreReference(
                        backend="kv_postgres",
                        namespace="prompts",
                    ).model_dump(exclude_none=True),
                },
            ),
            "run.yaml": base_run_settings,
            "run-with-postgres-store.yaml": postgres_run_settings,
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
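The refactor defines the run settings once and derives the Postgres variant with pydantic's `model_copy(update=..., deep=True)`, instead of spelling out a second, near-duplicate `RunConfigSettings`. A minimal standalone sketch of the pattern; the `Settings` model below is invented for illustration, not a Llama Stack class:

```python
# Deriving a config variant via pydantic model_copy; hypothetical model.
from pydantic import BaseModel

class Settings(BaseModel):
    storage_backends: dict[str, str] = {"kv_default": "sqlite", "sql_default": "sqlite"}
    port: int = 8321

base = Settings()
postgres = base.model_copy(
    update={"storage_backends": {"kv_default": "postgres", "sql_default": "postgres"}},
    deep=True,  # deep-copy nested containers so the variants don't share state
)
assert base.storage_backends["kv_default"] == "sqlite"
assert postgres.storage_backends["kv_default"] == "postgres"
```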
@@ -66,14 +66,6 @@ class InferenceStore:
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
            logger.debug(
                f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
            )

    async def shutdown(self) -> None:
        if not self._worker_tasks:
            return

@@ -94,10 +86,29 @@ class InferenceStore:
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def _ensure_workers_started(self) -> None:
        """Ensure the async write queue workers run on the current loop."""
        if not self.enable_write_queue:
            return

        if self._queue is None:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            logger.debug(
                f"Inference store write queue created with max size {self._max_write_queue_size} "
                f"and {self._num_writers} writers"
            )

        if not self._worker_tasks:
            loop = asyncio.get_running_loop()
            for _ in range(self._num_writers):
                task = loop.create_task(self._worker_loop())
                self._worker_tasks.append(task)

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        if self.enable_write_queue:
            await self._ensure_workers_started()
            if self._queue is None:
                raise ValueError("Inference store is not initialized")
            try:
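Creating the queue and worker tasks lazily, from inside a coroutine, ties them to the event loop that actually serves requests. If they were created during `initialize()` on a different loop (common in test harnesses that spin up a fresh loop per test), asyncio can fail with "got Future attached to a different loop". A condensed, standalone sketch of the pattern, simplified from the code above:

```python
# Condensed sketch of lazy, loop-bound worker startup.
import asyncio

class LazyWriter:
    def __init__(self, num_writers: int = 2) -> None:
        self._queue: asyncio.Queue[str] | None = None
        self._workers: list[asyncio.Task[None]] = []
        self._num_writers = num_writers

    async def _worker(self) -> None:
        assert self._queue is not None
        while True:
            item = await self._queue.get()
            try:
                print("wrote", item)  # stand-in for the real DB write
            finally:
                self._queue.task_done()

    async def _ensure_started(self) -> None:
        if self._queue is None:
            self._queue = asyncio.Queue()  # created on the live loop, not at init time
        if not self._workers:
            loop = asyncio.get_running_loop()
            self._workers = [loop.create_task(self._worker()) for _ in range(self._num_writers)]

    async def write(self, item: str) -> None:
        await self._ensure_started()
        assert self._queue is not None
        await self._queue.put(item)

async def main() -> None:
    w = LazyWriter()
    await w.write("hello")
    await w._queue.join()  # flush before exiting

asyncio.run(main())
```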
@@ -11,6 +11,9 @@

from __future__ import annotations

import asyncio
from collections import defaultdict

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType

from .api import KVStore

@@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore):


_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available KV store backends for reference resolution."""
    global _KVSTORE_BACKENDS
    global _KVSTORE_INSTANCES
    global _KVSTORE_LOCKS

    _KVSTORE_BACKENDS.clear()
    _KVSTORE_INSTANCES.clear()
    _KVSTORE_LOCKS.clear()
    for name, cfg in backends.items():
        _KVSTORE_BACKENDS[name] = cfg


async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    backend_name = reference.backend
    cache_key = (backend_name, reference.namespace)

    existing = _KVSTORE_INSTANCES.get(cache_key)
    if existing:
        return existing

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    config = backend_config.model_copy()
    config.namespace = reference.namespace
    lock = _KVSTORE_LOCKS[cache_key]
    async with lock:
        existing = _KVSTORE_INSTANCES.get(cache_key)
        if existing:
            return existing

    if config.type == StorageBackendType.KV_REDIS.value:
        from .redis import RedisKVStoreImpl
        config = backend_config.model_copy()
        config.namespace = reference.namespace

        impl = RedisKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_SQLITE.value:
        from .sqlite import SqliteKVStoreImpl
        if config.type == StorageBackendType.KV_REDIS.value:
            from .redis import RedisKVStoreImpl

        impl = SqliteKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_POSTGRES.value:
        from .postgres import PostgresKVStoreImpl
            impl = RedisKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_SQLITE.value:
            from .sqlite import SqliteKVStoreImpl

        impl = PostgresKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_MONGODB.value:
        from .mongodb import MongoDBKVStoreImpl
            impl = SqliteKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_POSTGRES.value:
            from .postgres import PostgresKVStoreImpl

        impl = MongoDBKVStoreImpl(config)
    else:
        raise ValueError(f"Unknown kvstore type {config.type}")
            impl = PostgresKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_MONGODB.value:
            from .mongodb import MongoDBKVStoreImpl

    await impl.initialize()
    return impl
            impl = MongoDBKVStoreImpl(config)
        else:
            raise ValueError(f"Unknown kvstore type {config.type}")

        await impl.initialize()
        _KVSTORE_INSTANCES[cache_key] = impl
        return impl
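The cache lookup happens twice: once without the lock (the fast path) and once inside it, so two coroutines racing to resolve the same backend/namespace pair cannot both construct and initialize an instance. A distilled, standalone sketch of this async double-checked caching pattern:

```python
# Distilled double-checked async caching, mirroring the kvstore instance cache.
import asyncio
from collections import defaultdict

_instances: dict[str, object] = {}
_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

async def get_or_create(key: str) -> object:
    if (existing := _instances.get(key)) is not None:
        return existing  # fast path, no lock taken
    async with _locks[key]:
        if (existing := _instances.get(key)) is not None:
            return existing  # another coroutine won the race while we waited
        instance = object()  # stand-in for construct + await initialize()
        _instances[key] = instance
        return instance
```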
@@ -704,34 +704,35 @@ class OpenAIVectorStoreMixin(ABC):
            # Unknown filter type, default to no match
            raise ValueError(f"Unsupported filter type: {filter_type}")

    def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]:
        # content is InterleavedContent
    def _chunk_to_vector_store_content(
        self, chunk: Chunk, include_embeddings: bool = False, include_metadata: bool = False
    ) -> list[VectorStoreContent]:
        def extract_fields() -> dict:
            """Extract embedding and metadata fields from chunk based on include flags."""
            return {
                "embedding": chunk.embedding if include_embeddings else None,
                "chunk_metadata": chunk.chunk_metadata if include_metadata else None,
                "metadata": chunk.metadata if include_metadata else None,
            }

        fields = extract_fields()

        if isinstance(chunk.content, str):
            content = [
                VectorStoreContent(
                    type="text",
                    text=chunk.content,
                )
            ]
            content_item = VectorStoreContent(type="text", text=chunk.content, **fields)
            content = [content_item]
        elif isinstance(chunk.content, list):
            # TODO: Add support for other types of content
            content = [
                VectorStoreContent(
                    type="text",
                    text=item.text,
                )
                for item in chunk.content
                if item.type == "text"
            ]
            content = []
            for item in chunk.content:
                if item.type == "text":
                    content_item = VectorStoreContent(type="text", text=item.text, **fields)
                    content.append(content_item)
        else:
            if chunk.content.type != "text":
                raise ValueError(f"Unsupported content type: {chunk.content.type}")
            content = [
                VectorStoreContent(
                    type="text",
                    text=chunk.content.text,
                )
            ]

            content_item = VectorStoreContent(type="text", text=chunk.content.text, **fields)
            content = [content_item]
        return content

    async def openai_attach_file_to_vector_store(

@@ -820,13 +821,12 @@ class OpenAIVectorStoreMixin(ABC):
                    message=str(e),
                )

            # Create OpenAI vector store file metadata
            # Save vector store file to persistent storage AFTER insert_chunks
            # so that chunks include the embeddings that were generated
            file_info = vector_store_file_object.model_dump(exclude={"last_error"})
            file_info["filename"] = file_response.filename if file_response else ""

            # Save vector store file to persistent storage (provider-specific)
            dict_chunks = [c.model_dump() for c in chunks]
            # This should be updated to include chunk_id
            await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)

            # Update file_ids and file_counts in vector store metadata

@@ -921,21 +921,27 @@ class OpenAIVectorStoreMixin(ABC):
        self,
        vector_store_id: str,
        file_id: str,
        include_embeddings: bool | None = False,
        include_metadata: bool | None = False,
    ) -> VectorStoreFileContentResponse:
        """Retrieves the contents of a vector store file."""
        if vector_store_id not in self.openai_vector_stores:
            raise VectorStoreNotFoundError(vector_store_id)

        # Parameters are already provided directly
        # include_embeddings and include_metadata are now function parameters

        dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
        chunks = [Chunk.model_validate(c) for c in dict_chunks]
        content = []
        for chunk in chunks:
            content.extend(self._chunk_to_vector_store_content(chunk))
            content.extend(
                self._chunk_to_vector_store_content(
                    chunk, include_embeddings=include_embeddings or False, include_metadata=include_metadata or False
                )
            )
        return VectorStoreFileContentResponse(
            object="vector_store.file_content.page",
            data=content,
            has_more=False,
            next_page=None,
        )

    async def openai_update_vector_store_file(
@@ -3,8 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any

from llama_stack.apis.agents import (
    Order,

@@ -19,12 +17,12 @@ from llama_stack.apis.agents.openai_responses import (
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference
from llama_stack.log import get_logger

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from ..sqlstore.sqlstore import sqlstore_impl

logger = get_logger(name=__name__, category="openai_responses")

@@ -55,28 +53,12 @@ class ResponsesStore:

        self.policy = policy
        self.sql_store = None
        self.enable_write_queue = True

        # Async write queue and worker control
        self._queue: (
            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
        ) = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = self.reference.max_write_queue_size
        self._num_writers: int = max(1, self.reference.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        base_store = sqlstore_impl(self.reference)
        self.sql_store = AuthorizedSqlStore(base_store, self.policy)

        # Disable write queue for SQLite since WAL mode handles concurrency
        # Keep it enabled for other backends (like Postgres) for performance
        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
            self.enable_write_queue = False
            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

        await self.sql_store.create_table(
            "openai_responses",
            {

@@ -95,33 +77,12 @@ class ResponsesStore:
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
            logger.debug(
                f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
            )

    async def shutdown(self) -> None:
        if not self._worker_tasks:
            return
        if self._queue is not None:
            await self._queue.join()
        for t in self._worker_tasks:
            if not t.done():
                t.cancel()
        for t in self._worker_tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass
        self._worker_tasks.clear()
        return

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()
        """Maintained for compatibility; no-op now that writes are synchronous."""
        return

    async def store_response_object(
        self,

@@ -129,31 +90,7 @@ class ResponsesStore:
        input: list[OpenAIResponseInput],
        messages: list[OpenAIMessageParam],
    ) -> None:
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Responses store is not initialized")
            try:
                self._queue.put_nowait((response_object, input, messages))
            except asyncio.QueueFull:
                logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
                await self._queue.put((response_object, input, messages))
        else:
            await self._write_response_object(response_object, input, messages)

    async def _worker_loop(self) -> None:
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                break
            response_object, input, messages = item
            try:
                await self._write_response_object(response_object, input, messages)
            except Exception as e:  # noqa: BLE001
                logger.error(f"Error writing response object: {e}")
            finally:
                self._queue.task_done()
        await self._write_response_object(response_object, input, messages)

    async def _write_response_object(
        self,

@@ -315,19 +252,12 @@ class ResponsesStore:
        # Serialize messages to dict format for JSON storage
        messages_data = [msg.model_dump() for msg in messages]

        # Upsert: try insert first, update if exists
        try:
            await self.sql_store.insert(
                table="conversation_messages",
                data={"conversation_id": conversation_id, "messages": messages_data},
            )
        except Exception:
            # If insert fails due to ID conflict, update existing record
            await self.sql_store.update(
                table="conversation_messages",
                data={"messages": messages_data},
                where={"conversation_id": conversation_id},
            )
        await self.sql_store.upsert(
            table="conversation_messages",
            data={"conversation_id": conversation_id, "messages": messages_data},
            conflict_columns=["conversation_id"],
            update_columns=["messages"],
        )

        logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
@@ -47,6 +47,18 @@ class SqlStore(Protocol):
        """
        pass

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        """
        Insert a row and update specified columns when conflicts occur.
        """
        pass

    async def fetch_all(
        self,
        table: str,
@@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
        enhanced["owner_principal"] = current_user.principal
        enhanced["access_attributes"] = current_user.attributes
    else:
        enhanced["owner_principal"] = None
        enhanced["access_attributes"] = None
        # IMPORTANT: Use empty string and null value (not None) to match public access filter
        # The public access filter in _get_public_access_conditions() expects:
        # - owner_principal = '' (empty string)
        # - access_attributes = null (JSON null, which serializes to the string 'null')
        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
        enhanced["owner_principal"] = ""
        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
    return enhanced


@@ -124,6 +129,23 @@ class AuthorizedSqlStore:
        enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
        await self.sql_store.insert(table, enhanced_data)

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        """Upsert a row with automatic access control attribute capture."""
        current_user = get_authenticated_user()
        enhanced_data = _enhance_item_with_access_control(data, current_user)
        await self.sql_store.upsert(
            table=table,
            data=enhanced_data,
            conflict_columns=conflict_columns,
            update_columns=update_columns,
        )

    async def fetch_all(
        self,
        table: str,

@@ -188,8 +210,9 @@ class AuthorizedSqlStore:
            enhanced_data["owner_principal"] = current_user.principal
            enhanced_data["access_attributes"] = current_user.attributes
        else:
            enhanced_data["owner_principal"] = None
            enhanced_data["access_attributes"] = None
            # IMPORTANT: Use empty string for owner_principal to match public access filter
            enhanced_data["owner_principal"] = ""
            enhanced_data["access_attributes"] = None  # Will serialize as JSON null

        await self.sql_store.update(table, enhanced_data, where)

@@ -245,14 +268,24 @@ class AuthorizedSqlStore:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _get_public_access_conditions(self) -> list[str]:
        """Get the SQL conditions for public access."""
        # Public records are records that have no owner_principal or access_attributes
        """Get the SQL conditions for public access.

        Public records are those with:
        - owner_principal = '' (empty string)
        - access_attributes is either SQL NULL or JSON null

        Note: Different databases serialize None differently:
        - SQLite: None → JSON null (text = 'null')
        - Postgres: None → SQL NULL (IS NULL)
        """
        conditions = ["owner_principal = ''"]
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            # Postgres stores JSON null as 'null'
            conditions.append("access_attributes::text = 'null'")
            # Accept both SQL NULL and JSON null for Postgres compatibility
            # This handles both old rows (SQL NULL) and new rows (JSON null)
            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            conditions.append("access_attributes = 'null'")
            # SQLite serializes None as JSON null
            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")
        return conditions
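The distinction these comments draw is easy to trip over: a Python `None` pushed through a JSON serializer becomes the four-character string `null` in the column, while a value that was never written is a SQL `NULL` that only `IS NULL` can match. A small standalone sketch of the difference, with `sqlite3` and `json` standing in for the real drivers:

```python
# Demonstrates JSON null vs SQL NULL, the two "empty" shapes the filter accepts.
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (owner TEXT, attrs TEXT)")
conn.execute("INSERT INTO t VALUES ('', ?)", (json.dumps(None),))  # JSON null -> 'null'
conn.execute("INSERT INTO t VALUES ('', NULL)")                    # SQL NULL

# Matching only the JSON-null form misses the SQL NULL row...
assert conn.execute("SELECT COUNT(*) FROM t WHERE attrs = 'null'").fetchone()[0] == 1
# ...so the public-access filter accepts both representations.
both = "attrs IS NULL OR attrs = 'null'"
assert conn.execute(f"SELECT COUNT(*) FROM t WHERE {both}").fetchone()[0] == 2
```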
@@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
class SqlAlchemySqlStoreImpl(SqlStore):
    def __init__(self, config: SqlAlchemySqlStoreConfig):
        self.config = config
        self._is_sqlite_backend = "sqlite" in self.config.engine_str
        self.async_session = async_sessionmaker(self.create_engine())
        self.metadata = MetaData()

    def create_engine(self) -> AsyncEngine:
        # Configure connection args for better concurrency support
        connect_args = {}
        if "sqlite" in self.config.engine_str:
        if self._is_sqlite_backend:
            # SQLite-specific optimizations for concurrent access
            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
            connect_args["timeout"] = 5.0

@@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
        )

        # Enable WAL mode for SQLite to support concurrent readers and writers
        if "sqlite" in self.config.engine_str:
        if self._is_sqlite_backend:

            @event.listens_for(engine.sync_engine, "connect")
            def set_sqlite_pragma(dbapi_conn, connection_record):

@@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
            await session.execute(self.metadata.tables[table].insert(), data)
            await session.commit()

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        table_obj = self.metadata.tables[table]
        dialect_insert = self._get_dialect_insert(table_obj)
        insert_stmt = dialect_insert.values(**data)

        if update_columns is None:
            update_columns = [col for col in data.keys() if col not in conflict_columns]

        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
        conflict_cols = [table_obj.c[col] for col in conflict_columns]

        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)

        async with self.async_session() as session:
            await session.execute(stmt)
            await session.commit()

    async def fetch_all(
        self,
        table: str,

@@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                    add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")

                    await conn.execute(add_column_sql)

        except Exception as e:
            # If any error occurs during migration, log it but don't fail
            # The table creation will handle adding the column
            logger.error(f"Error adding column {column_name} to table {table}: {e}")
            pass

    def _get_dialect_insert(self, table: Table):
        if self._is_sqlite_backend:
            from sqlalchemy.dialects.sqlite import insert as sqlite_insert

            return sqlite_insert(table)
        else:
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            return pg_insert(table)
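The `upsert` implementation builds a dialect-specific `INSERT ... ON CONFLICT DO UPDATE` through SQLAlchemy's per-dialect insert constructs. A standalone sketch of the same construct against an in-memory SQLite table; the table and columns are invented for illustration:

```python
# Demo of the dialect-specific upsert the implementation above builds.
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
items = Table("items", metadata, Column("id", String, primary_key=True), Column("count", Integer))
metadata.create_all(engine)

def upsert(conn, row: dict) -> None:
    stmt = sqlite_insert(items).values(**row)
    # On conflict, rewrite every non-key column with the incoming ("excluded") value
    update_cols = {c: getattr(stmt.excluded, c) for c in row if c != "id"}
    conn.execute(stmt.on_conflict_do_update(index_elements=[items.c.id], set_=update_cols))

with engine.begin() as conn:
    upsert(conn, {"id": "a", "count": 1})
    upsert(conn, {"id": "a", "count": 2})  # conflicts on id; updates instead of failing
    assert conn.execute(select(items.c.count)).scalar_one() == 2
```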
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from threading import Lock
from typing import Annotated, cast

from pydantic import Field

@@ -21,6 +22,8 @@ from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
_SQLSTORE_LOCKS: dict[str, Lock] = {}


SqlStoreConfig = Annotated[

@@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
            f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
        )

    if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
    existing = _SQLSTORE_INSTANCES.get(backend_name)
    if existing:
        return existing

        config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
        return SqlAlchemySqlStoreImpl(config)
    else:
        raise ValueError(f"Unknown sqlstore type {backend_config.type}")
    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
    with lock:
        existing = _SQLSTORE_INSTANCES.get(backend_name)
        if existing:
            return existing

        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
            instance = SqlAlchemySqlStoreImpl(config)
            _SQLSTORE_INSTANCES[backend_name] = instance
            return instance
        else:
            raise ValueError(f"Unknown sqlstore type {backend_config.type}")


def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available SQL store backends for reference resolution."""
    global _SQLSTORE_BACKENDS
    global _SQLSTORE_INSTANCES

    _SQLSTORE_BACKENDS.clear()
    _SQLSTORE_INSTANCES.clear()
    _SQLSTORE_LOCKS.clear()
    for name, cfg in backends.items():
        _SQLSTORE_BACKENDS[name] = cfg