mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-09 13:14:39 +00:00
feat: Adding OpenAI Compatible Prompts API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
30117dea22
commit
8b00883abd
181 changed files with 21356 additions and 10332 deletions
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -12,7 +13,7 @@ from chromadb import PersistentClient
|
|||
from pymilvus import MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
|
@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
|
|||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
|
@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
|
|||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
|||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psycopg2_connection():
|
||||
connection = MagicMock()
|
||||
cursor = MagicMock()
|
||||
|
||||
cursor.__enter__ = MagicMock(return_value=cursor)
|
||||
cursor.__exit__ = MagicMock()
|
||||
|
||||
connection.cursor.return_value = cursor
|
||||
|
||||
return connection, cursor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id="pgvector",
|
||||
provider_resource_id="pgvector:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index._test_chunks = []
|
||||
original_add_chunks = index.add_chunks
|
||||
|
||||
async def mock_add_chunks(chunks, embeddings):
|
||||
index._test_chunks = list(chunks)
|
||||
await original_add_chunks(chunks, embeddings)
|
||||
|
||||
index.add_chunks = mock_add_chunks
|
||||
|
||||
async def mock_query_vector(embedding, k, score_threshold):
|
||||
chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
|
||||
scores = [1.0] * len(chunks)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
index.query_vector = mock_query_vector
|
||||
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
||||
config = PGVectorVectorIOConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_conn.autocommit = True
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
with patch(
|
||||
"llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
|
||||
) as mock_check_version:
|
||||
mock_check_version.return_value = "0.5.1"
|
||||
|
||||
with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
|
||||
mock_kvstore = AsyncMock()
|
||||
mock_kvstore_impl.return_value = mock_kvstore
|
||||
|
||||
with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
|
||||
await adapter.initialize()
|
||||
adapter.conn = mock_conn
|
||||
|
||||
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
adapter.insert_chunks = mock_insert_chunks
|
||||
|
||||
async def mock_query_chunks(vector_db_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
adapter.query_chunks = mock_query_chunks
|
||||
|
||||
test_vector_db = VectorDB(
|
||||
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
await adapter.register_vector_db(test_vector_db)
|
||||
adapter.test_collection_id = test_vector_db.identifier
|
||||
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
"""Returns the appropriate vector IO adapter based on the provider parameter."""
|
||||
vector_provider_dict = {
|
||||
"milvus": "milvus_vec_adapter",
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue