Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
sqlite-vec support for Responses file_search
This wires the Files API up as an optional dependency of sqlite_vec and adds the localfs Files provider to our starter template, so that the Responses API file_search tool works out of the box with sqlite_vec in that template.

Some additional testing with this provider and a few other inference models led me to loosen the verification test checks a bit - not around the tool call itself, but around the assistant response that follows the file_search tool call. Some providers, such as OpenAI SaaS, sometimes make multiple tool calls to resolve the query, especially when they cannot find an answer and try a few permutations before returning empty results to the user in that test.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
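For context, here is a minimal sketch of how a client could exercise this end to end once the starter distribution is running. It is illustrative only: the base URL (assuming the stack's OpenAI-compatible endpoint), the model id, and the file name are assumptions for this example, not part of this change, and it assumes a recent openai-python client.

# Illustrative sketch only - base_url, model id, and file path are assumptions.
from openai import OpenAI

# Llama Stack exposes an OpenAI-compatible endpoint; adjust host/port for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload a document and index it into a vector store (backed by sqlite-vec in the starter template).
uploaded = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")
vector_store = client.vector_stores.create(name="my-docs")
client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded.id)

# Ask a question with the file_search tool pointed at that vector store.
response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # assumed model id
    input="What do my notes say about quarterly goals?",
    tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
    include=["file_search_call.results"],
)
print(response.output_text)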
Parent: ec09524a91
Commit: 7a71d9ebd8
8 changed files with 30 additions and 24 deletions
@@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
     await impl.initialize()
     return impl

@@ -17,6 +17,7 @@ import numpy as np
 import sqlite_vec
 from numpy.typing import NDArray

+from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
@@ -24,7 +25,6 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
-from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
@@ -302,9 +302,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Inference) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
         self.config = config
         self.inference_api = inference_api
+        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}

@@ -490,15 +491,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)

-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores Files API is not supported in sqlite_vec")
-

 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""

@@ -44,6 +44,7 @@ def available_providers() -> list[ProviderSpec]:
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
@@ -53,6 +54,7 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         remote_provider_spec(
             Api.vector_io,

@@ -425,7 +425,7 @@ class OpenAIVectorStoreMixin(ABC):
             vector_store_id=vector_store_id,
         )

-        if not self.files_api:
+        if not hasattr(self, "files_api") or not self.files_api:
             vector_store_file_object.status = "failed"
             vector_store_file_object.last_error = VectorStoreFileLastError(
                 code="server_error",

@@ -17,6 +17,8 @@ distribution_spec:
     - inline::sqlite-vec
     - remote::chromadb
     - remote::pgvector
+    files:
+    - inline::localfs
     safety:
     - inline::llama-guard
     agents:

@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
@@ -75,6 +76,14 @@ providers:
         db: ${env.PGVECTOR_DB:}
         user: ${env.PGVECTOR_USER:}
         password: ${env.PGVECTOR_PASSWORD:}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/files_metadata.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

@@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -134,6 +135,7 @@ def get_distribution_template() -> DistributionTemplate:
     providers = {
         "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
+        "files": ["inline::localfs"],
         "safety": ["inline::llama-guard"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
@@ -170,6 +172,11 @@ def get_distribution_template() -> DistributionTemplate:
             ),
         ),
     ]
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     embedding_provider = Provider(
         provider_id="sentence-transformers",
         provider_type="inline::sentence-transformers",
@@ -212,6 +219,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": inference_providers + [embedding_provider],
                 "vector_io": vector_io_providers,
+                "files": [files_provider],
             },
             default_models=default_models + [embedding_model],
             default_tool_groups=default_tool_groups,

@@ -340,7 +340,7 @@ def test_response_non_streaming_file_search(
     response = openai_client.responses.create(
         model=model,
         input=case["input"],
-        tools=case["tools"],
+        tools=tools,
         stream=False,
         include=["file_search_call.results"],
     )
@@ -354,11 +354,7 @@ def test_response_non_streaming_file_search(
     assert case["output"].lower() in response.output[0].results[0].text.lower()
     assert response.output[0].results[0].score > 0

-    # Verify the assistant response that summarizes the results
-    assert response.output[1].type == "message"
-    assert response.output[1].status == "completed"
-    assert response.output[1].role == "assistant"
-    assert len(response.output[1].content) > 0
+    # Verify the output_text generated by the response
     assert case["output"].lower() in response.output_text.lower().strip()


@@ -390,11 +386,8 @@ def test_response_non_streaming_file_search_empty_vector_store(
     assert response.output[0].queries # ensure it's some non-empty list
     assert not response.output[0].results # ensure we don't get any results

-    # Verify the assistant response that summarizes the results
-    assert response.output[1].type == "message"
-    assert response.output[1].status == "completed"
-    assert response.output[1].role == "assistant"
-    assert len(response.output[1].content) > 0
+    # Verify some output_text was generated by the response
+    assert response.output_text


 @pytest.mark.parametrize(