llama-stack/llama_stack/providers/tests/vector_io/fixtures.py
Francisco Arceo 119fe8742a
feat: Adding sqlite-vec as a vectordb (#1040)
# What does this PR do?
This PR adds `sqlite_vec` as an additional inline vectordb.

Tested with `ollama` by adding the `vector_io` object in
`./llama_stack/templates/ollama/run.yaml`:

```yaml
  vector_io:
  - provider_id: sqlite_vec
    provider_type: inline::sqlite_vec
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db
```
I also updated the `./tests/client-sdk/vector_io/test_vector_io.py` test
file with:
```python
INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"]
```
And parameterized the relevant tests. 

[//]: # (If resolving an issue, uncomment and update the line below)
# Closes 
https://github.com/meta-llama/llama-stack/issues/1005

## Test Plan
I ran the tests with:
```bash
INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py
```
Which outputs:
```python
...
PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED
tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED
```

In addition, I ran the `rag_with_vector_db.py`
[example](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py)
using the script below with `uv run rag_example.py`.
<details>
<summary>CLICK TO SHOW SCRIPT 👋  </summary>

```python
#!/usr/bin/env python3
import os
import uuid
from termcolor import cprint

# Set environment variables
os.environ['INFERENCE_MODEL'] = 'llama3.2:3b-instruct-fp16'
os.environ['LLAMA_STACK_CONFIG'] = 'ollama'

# Import libraries after setting environment variables
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document


def main():
    # Initialize the client
    client = LlamaStackAsLibraryClient("ollama")
    vector_db_id = f"test-vector-db-{uuid.uuid4().hex}"

    _ = client.initialize()

    model_id = 'llama3.2:3b-instruct-fp16'

    # Define the list of document URLs and create Document objects
    urls = [
        "chat.rst",
        "llama3.rst",
        "memory_optimizations.rst",
        "lora_finetune.rst",
    ]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]
    # (Optional) Use the documents as needed with your client here

    client.vector_dbs.register(
        provider_id='sqlite_vec',
        vector_db_id=vector_db_id,
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )

    client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
    # Create agent configuration
    agent_config = AgentConfig(
        model=model_id,
        instructions="You are a helpful assistant",
        enable_session_persistence=False,
        toolgroups=[
            {
                "name": "builtin::rag",
                "args": {
                    "vector_db_ids": [vector_db_id],
                }
            }
        ],
    )

    # Instantiate the Agent
    agent = Agent(client, agent_config)

    # List of user prompts
    user_prompts = [
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
        "What about Quantization?",
    ]

    # Create a session for the agent
    session_id = agent.create_session("test-session")

    # Process each prompt and display the output
    for prompt in user_prompts:
        cprint(f"User> {prompt}", "green")
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )
        # Log and print events from the response
        for log in EventLogger().log(response):
            log.print()


if __name__ == "__main__":
    main()
```
</details>

This outputs a large summary of RAG generation.

# Documentation

Documentation updates will be handled in a follow-up PR.

# (- [ ] Added a Changelog entry if the change is significant)

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
2025-02-12 10:50:03 -08:00

167 lines
5.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def embedding_model(request):
    """Resolve the embedding model identifier used for the test session.

    A value supplied through fixture parametrization (``request.param``)
    takes precedence; otherwise the ``--embedding-model`` command-line option
    is read, with ``None`` passed as its default.
    """
    try:
        # ``param`` only exists on the request when the fixture is parametrized.
        return request.param
    except AttributeError:
        return request.config.getoption("--embedding-model", None)
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
    """Return the provider fixture produced by ``remote_stack_fixture``."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
    """Build the provider fixture for the inline FAISS vector_io provider.

    The provider's kvstore is a SQLite store backed by a fresh temporary
    ``.db`` file.

    Returns:
        A ``ProviderFixture`` containing a single ``inline::faiss`` provider.
    """
    # Only the path is needed: ``delete=False`` keeps the file on disk, and
    # leaving the ``with`` block closes the handle so no file descriptor is
    # leaked for the rest of the session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissImplConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_sqlite_vec() -> ProviderFixture:
    """Build the provider fixture for the inline sqlite_vec vector_io provider.

    The provider's kvstore is a SQLite store backed by a fresh temporary
    ``.db`` file.

    Returns:
        A ``ProviderFixture`` containing a single ``inline::sqlite_vec``
        provider.
    """
    # Only the path is needed: ``delete=False`` keeps the file on disk, and
    # leaving the ``with`` block closes the handle so no file descriptor is
    # leaked for the rest of the session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sqlite_vec",
                provider_type="inline::sqlite_vec",
                config=SQLiteVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
    """Build the provider fixture for the remote pgvector vector_io provider.

    Connection settings come from the environment: ``PGVECTOR_HOST``
    (default ``"localhost"``) and ``PGVECTOR_PORT`` (default ``5432``) are
    optional, while ``PGVECTOR_DB``, ``PGVECTOR_USER`` and
    ``PGVECTOR_PASSWORD`` are read through ``get_env_or_fail``.

    Returns:
        A ``ProviderFixture`` containing a single ``remote::pgvector``
        provider.

    Raises:
        ValueError: If ``PGVECTOR_PORT`` is set to a non-integer value.
    """
    # os.getenv returns a str when the variable is set, so convert explicitly
    # to hand the config an int in both the default and the overridden case.
    port = int(os.getenv("PGVECTOR_PORT", "5432"))

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    port=port,
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
    """Build the provider fixture for the remote Weaviate vector_io provider.

    The provider uses a default ``WeaviateConfig``; the API key and cluster
    URL are read via ``get_env_or_fail`` from ``WEAVIATE_API_KEY`` and
    ``WEAVIATE_CLUSTER_URL`` and passed along as provider data.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateConfig().model_dump(),
    )
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(
        providers=[weaviate_provider],
        provider_data=credentials,
    )
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
    """Build the provider fixture for Chroma, remote or inline.

    When ``CHROMA_URL`` is set, a ``remote::chromadb`` provider pointing at
    that URL is configured. Otherwise ``CHROMA_DB_PATH`` must be set and an
    ``inline::chromadb`` provider using that path is configured.

    Raises:
        ValueError: If neither ``CHROMA_URL`` nor ``CHROMA_DB_PATH`` is set.
    """
    remote_url = os.getenv("CHROMA_URL")
    if remote_url:
        chroma_config = ChromaRemoteImplConfig(url=remote_url)
        chroma_provider_type = "remote::chromadb"
    else:
        local_db_path = os.getenv("CHROMA_DB_PATH")
        if not local_db_path:
            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
        chroma_config = ChromaInlineImplConfig(db_path=local_db_path)
        chroma_provider_type = "inline::chromadb"

    chroma_provider = Provider(
        provider_id="chroma",
        provider_type=chroma_provider_type,
        config=chroma_config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
# Name suffixes of the ``vector_io_<name>`` provider fixtures defined in this
# module.
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "sqlite_vec"]
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
    """Construct a test stack exposing the vector_io and inference APIs.

    ``request.param`` is a mapping with ``"inference"`` and ``"vector_io"``
    keys whose values select the provider fixture to load for each API (the
    fixture looked up is named ``<api>_<value>``). The embedding model is
    registered with its dimension taken from the ``EMBEDDING_DIMENSION``
    environment variable.

    Returns:
        A tuple ``(vector_io_impl, vector_dbs_impl)`` taken from the
        constructed stack's implementations.
    """
    selection = request.param

    providers_by_api = {}
    merged_provider_data = {}
    for api_name in ("inference", "vector_io"):
        provider_fixture = request.getfixturevalue(f"{api_name}_{selection[api_name]}")
        providers_by_api[api_name] = provider_fixture.providers
        # Not every provider fixture carries provider data; merge only when present.
        if provider_fixture.provider_data:
            merged_provider_data.update(provider_fixture.provider_data)

    embedding_model_input = ModelInput(
        model_id=embedding_model,
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
        },
    )
    stack = await construct_stack_for_test(
        [Api.vector_io, Api.inference],
        providers_by_api,
        merged_provider_data,
        models=[embedding_model_input],
    )

    impls = stack.impls
    return impls[Api.vector_io], impls[Api.vector_dbs]