Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-17 02:18:13 +00:00)
# What does this PR do?

Resolves https://github.com/meta-llama/llama-stack/issues/2735

Currently, if you test against OpenAI's Vector Stores API, the `client.vector_stores.search` call fails with an invalid vector_db during routing (see the script referenced in the clickable item under the Test Plan section).

This PR ensures that `client.vector_stores.search()` is compatible with OpenAI's Vector Stores API.

The two biggest changes:
1. The `name`, which was previously used as the `vector_db_id`, has been replaced with an identifier that is consistent with OpenAI's `vs_{uuid}` format.
2. A vector store has to be referenced by its ID; the name is not reliable, since every `client.vector_stores.create` call results in a new vector store.

NOTE: I believe this is a breaking change for end users, as they'll need to update their VectorDB identifiers; a minimal migration sketch is included at the end of this description.

## Test Plan

Unit tests:

```bash
./scripts/unit-tests.sh tests/unit/providers/vector_io/ -v
```

Integration tests:

```bash
ENABLE_MILVUS=milvus llama stack run /Users/farceo/dev/llama-stack/llama_stack/templates/starter/run.yaml --image-type venv

LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/vector_io/test_openai_vector_stores.py --embedding-model=all-MiniLM-L6-v2 -vv
```

Unit tests and test script below 👇

<details>
<summary>Click here for script used to test OpenAI and Llama Stack Vector Store implementation</summary>

```python
import json
import argparse
from openai import OpenAI, pagination
import logging
from colorama import Fore, Style, init
import traceback
import os

# Initialize colorama for color support in terminal
init(autoreset=True)

# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

DEMO_VECTOR_STORE_NAME = "Support FAQ FJA"
global DEMO_VECTOR_STORE_ID
global DEMO_VECTOR_STORE_ID2


def colored_print(color, text):
    """Prints text to the console with the specified color."""
    print(f"{color}{text}{Style.RESET_ALL}")


def log_and_print(color, message, level=logging.INFO):
    """Logs a message and prints it to the console with the specified color."""
    logging.log(level, message)
    colored_print(color, message)


def run_tests(client, prefix="openai"):
    """
    Runs all tests using the provided OpenAI client and saves the output to
    JSON files with the given prefix.
    """
    # Create the directory if it doesn't exist
    os.makedirs('openai_testing', exist_ok=True)

    # Default values in case tests fail
    global DEMO_VECTOR_STORE_ID, DEMO_VECTOR_STORE_ID2
    DEMO_VECTOR_STORE_ID = None
    DEMO_VECTOR_STORE_ID2 = None

    def test_idempotent_vector_store_creation():
        """
        Test that creating a vector store with the same name is idempotent.
        """
        log_and_print(Fore.BLUE, "Starting vector store creation test...")
        try:
            vector_store = client.vector_stores.create(
                name=DEMO_VECTOR_STORE_NAME,
            )

            # Attempt to create the same vector store again
            vector_store2 = client.vector_stores.create(
                name=DEMO_VECTOR_STORE_NAME,
            )

            # Check instead of assert
            if vector_store2.id != vector_store.id:
                log_and_print(Fore.YELLOW, f"FAILED IDEMPOTENCY: the same VectorStore name for {prefix.upper()} does not return the same ID", level=logging.WARNING)
            else:
                log_and_print(Fore.GREEN, f"PASSED IDEMPOTENCY: {vector_store2.id} == {vector_store.id} the same VectorStore name for {prefix.upper()} returns the same ID")

            vector_store_data = vector_store.to_dict()
            log_and_print(Fore.WHITE, f"vector_stores.create = {json.dumps(vector_store_data, indent=2)}")
            with open(f'openai_testing/{prefix}_vector_store_create.json', 'w') as f:
                json.dump(vector_store_data, f, indent=2)

            global DEMO_VECTOR_STORE_ID, DEMO_VECTOR_STORE_ID2
            DEMO_VECTOR_STORE_ID = vector_store.id
            DEMO_VECTOR_STORE_ID2 = vector_store2.id
            return DEMO_VECTOR_STORE_ID, DEMO_VECTOR_STORE_ID2
        except Exception as e:
            log_and_print(Fore.RED, f"Idempotent vector store creation test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())
            # Create a fallback vector store ID if needed
            if 'vector_store' in locals() and vector_store:
                DEMO_VECTOR_STORE_ID = vector_store.id
            return DEMO_VECTOR_STORE_ID, DEMO_VECTOR_STORE_ID2

    def test_vector_store_list():
        """
        Test listing vector stores.
        """
        log_and_print(Fore.BLUE, "Starting vector store list test...")
        try:
            vector_stores = client.vector_stores.list()

            # Check instead of assert
            if not isinstance(vector_stores, pagination.SyncCursorPage):
                log_and_print(Fore.YELLOW, f"FAILED: Expected a list of vector stores, got {type(vector_stores)}", level=logging.WARNING)
            else:
                log_and_print(Fore.GREEN, "Vector store list test passed!")

            vector_stores_data = vector_stores.to_dict()
            log_and_print(Fore.WHITE, f"vector_stores.list = {json.dumps(vector_stores_data, indent=2)}")
            with open(f'openai_testing/{prefix}_vector_store_list.json', 'w') as f:
                json.dump(vector_stores_data, f, indent=2)
        except Exception as e:
            log_and_print(Fore.RED, f"Vector store list test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    def test_retrieve_vector_store():
        """
        Test retrieving a specific vector store.
        """
        log_and_print(Fore.BLUE, "Starting retrieve vector store test...")
        if not DEMO_VECTOR_STORE_ID:
            log_and_print(Fore.YELLOW, "Skipping retrieve vector store test - no vector store ID available", level=logging.WARNING)
            return
        try:
            vector_store = client.vector_stores.retrieve(
                vector_store_id=DEMO_VECTOR_STORE_ID,
            )

            # Check instead of assert
            if vector_store.id != DEMO_VECTOR_STORE_ID:
                log_and_print(Fore.YELLOW, "FAILED: Retrieved vector store ID does not match", level=logging.WARNING)
            else:
                log_and_print(Fore.GREEN, "Retrieve vector store test passed!")

            vector_store_data = vector_store.to_dict()
            log_and_print(Fore.WHITE, f"vector_stores.retrieve = {json.dumps(vector_store_data, indent=2)}")
            with open(f'openai_testing/{prefix}_vector_store_retrieve.json', 'w') as f:
                json.dump(vector_store_data, f, indent=2)
        except Exception as e:
            log_and_print(Fore.RED, f"Retrieve vector store test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    def test_modify_vector_store():
        """
        Test modifying a vector store.
        """
        log_and_print(Fore.BLUE, "Starting modify vector store test...")
        if not DEMO_VECTOR_STORE_ID:
            log_and_print(Fore.YELLOW, "Skipping modify vector store test - no vector store ID available", level=logging.WARNING)
            return
        try:
            updated_vector_store = client.vector_stores.update(
                vector_store_id=DEMO_VECTOR_STORE_ID,
                name="Updated Support FAQ FJA",
            )

            # Check instead of assert
            if updated_vector_store.name != "Updated Support FAQ FJA":
                log_and_print(Fore.YELLOW, "FAILED: Vector store name was not updated correctly", level=logging.WARNING)
            else:
                log_and_print(Fore.GREEN, "Modify vector store test passed!")

            updated_vector_store_data = updated_vector_store.to_dict()
            log_and_print(Fore.WHITE, f"vector_stores.modify = {json.dumps(updated_vector_store_data, indent=2)}")
            with open(f'openai_testing/{prefix}_vector_store_modify.json', 'w') as f:
                json.dump(updated_vector_store_data, f, indent=2)
        except Exception as e:
            log_and_print(Fore.RED, f"Modify vector store test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    def test_delete_vector_store():
        """
        Test deleting a vector store.
        """
        log_and_print(Fore.BLUE, "Starting delete vector store test...")
        if not DEMO_VECTOR_STORE_ID2:
            log_and_print(Fore.YELLOW, "Skipping delete vector store test - no second vector store ID available", level=logging.WARNING)
            return
        try:
            response = client.vector_stores.delete(
                vector_store_id=DEMO_VECTOR_STORE_ID2,
            )
            log_and_print(Fore.GREEN, "Delete vector store test passed!")

            response_data = response.to_dict()
            log_and_print(Fore.WHITE, f"Vector store delete response = {json.dumps(response_data, indent=2)}")
            with open(f'openai_testing/{prefix}_vector_store_delete.json', 'w') as f:
                json.dump(response_data, f, indent=2)
        except Exception as e:
            log_and_print(Fore.RED, f"Delete vector store test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    def test_create_vector_store_file():
        log_and_print(Fore.BLUE, "Starting create vector store file test...")
        if not DEMO_VECTOR_STORE_ID:
            log_and_print(Fore.YELLOW, "Skipping create vector store file test - no vector store ID available", level=logging.WARNING)
            return
        try:
            # create jsonl of files as an example
            with open("mydata.jsonl", "w") as f:
                f.write('{"text": "What is the return policy?", "metadata": {"category": "support"}}\n')
                f.write('{"text": "How do I reset my password?", "metadata": {"category": "support"}}\n')
                f.write('{"text": "Where can I find my order history?", "metadata": {"category": "support"}}\n')
                f.write('{"text": "What are the shipping options?", "metadata": {"category": "support"}}\n')
                f.write('{"text": "What is your favorite banana?", "metadata": {"category": "support"}}\n')

            # Create a simple text file if my_data_small.txt doesn't exist
            if not os.path.exists("my_data_small.txt"):
                with open("my_data_small.txt", "w") as f:
                    f.write("This is a test file for vector store testing.\n")

            created_file = client.files.create(
                file=open("my_data_small.txt", "rb"),
                purpose="assistants",
            )
            created_file_data = created_file.to_dict()
            log_and_print(Fore.WHITE, f"Created file {json.dumps(created_file_data, indent=2)}")
            with open(f'openai_testing/{prefix}_file_create.json', 'w') as f:
                json.dump(created_file_data, f, indent=2)

            retrieved_files = client.files.retrieve(created_file.id)
            retrieved_files_data = retrieved_files.to_dict()
            log_and_print(Fore.WHITE, f"Retrieved file {json.dumps(retrieved_files_data, indent=2)}")
            with open(f'openai_testing/{prefix}_file_retrieve.json', 'w') as f:
                json.dump(retrieved_files_data, f, indent=2)

            vector_store_file = client.vector_stores.files.create(
                vector_store_id=DEMO_VECTOR_STORE_ID,
                file_id=created_file.id,
            )
            log_and_print(Fore.GREEN, "Create vector store file test passed!")
        except Exception as e:
            log_and_print(Fore.RED, f"Create vector store file test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    def test_search_vector_store():
        """
        Test searching a vector store.
        """
        log_and_print(Fore.BLUE, "Starting search vector store test...")
        if not DEMO_VECTOR_STORE_ID:
            log_and_print(Fore.YELLOW, "Skipping search vector store test - no vector store ID available", level=logging.WARNING)
            return
        try:
            query = "What is the banana policy?"
            search_results = client.vector_stores.search(
                vector_store_id=DEMO_VECTOR_STORE_ID,
                query=query,
                max_num_results=10,
                ranking_options={
                    'ranker': 'default-2024-11-15',
                    'score_threshold': 0.0,
                },
                rewrite_query=False,
            )

            # Check instead of assert
            if not isinstance(search_results, pagination.SyncPage):
                log_and_print(Fore.YELLOW, f"FAILED: Expected a list of search results, got {type(search_results)}", level=logging.WARNING)
            else:
                log_and_print(Fore.GREEN, "Search vector store test passed!")

            search_results_dict = search_results.to_dict()
            log_and_print(Fore.WHITE, f"Search results = {search_results_dict}")
            with open(f'openai_testing/{prefix}_vector_store_search.json', 'w') as f:
                json.dump(search_results_dict, f, indent=2)

            log_and_print(Fore.WHITE, f"vector_stores.search = {search_results.to_json()}")
        except Exception as e:
            log_and_print(Fore.RED, f"Search vector store test failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())

    # Run all tests in sequence, even if some fail
    test_results = []

    try:
        result = test_idempotent_vector_store_creation()
        if result and len(result) == 2:
            DEMO_VECTOR_STORE_ID, DEMO_VECTOR_STORE_ID2 = result
        test_results.append(True)
    except Exception as e:
        log_and_print(Fore.RED, f"Vector store creation test failed: {e}", level=logging.ERROR)
        logging.error(traceback.format_exc())
        test_results.append(False)

    for test_func in [
        test_vector_store_list,
        test_retrieve_vector_store,
        test_modify_vector_store,
        test_delete_vector_store,
        test_create_vector_store_file,
        test_search_vector_store,
    ]:
        try:
            test_func()
            test_results.append(True)
        except Exception as e:
            log_and_print(Fore.RED, f"{test_func.__name__} failed: {e}", level=logging.ERROR)
            logging.error(traceback.format_exc())
            test_results.append(False)

    if all(test_results):
        log_and_print(Fore.GREEN, f"All {prefix} tests completed successfully!")
    else:
        failed_count = test_results.count(False)
        log_and_print(Fore.YELLOW, f"{failed_count} {prefix} test(s) failed, but script completed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run OpenAI and/or LlamaStack tests.")
    parser.add_argument(
        "--provider",
        type=str,
        default="llama",
        choices=["openai", "llama", "both"],
        help="Specify which environment to test: openai, llama, or both. Default is llama.",
    )
    args = parser.parse_args()

    try:
        if args.provider in ("openai", "both"):
            openai_client = OpenAI()
            run_tests(openai_client, prefix="openai")

        if args.provider in ("llama", "both"):
            llama_client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
            run_tests(llama_client, prefix="llama")

        log_and_print(Fore.GREEN, "All tests completed!")
    except Exception as e:
        log_and_print(Fore.RED, f"Tests failed to complete: {e}", level=logging.ERROR)
        logging.error(traceback.format_exc())
```

</details>

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
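As a quick illustration of the migration above, here is a minimal sketch of the new ID-based flow. It assumes a local Llama Stack server exposing the OpenAI-compatible endpoint at `http://localhost:8321/v1/openai/v1` (the same base URL the test script uses); the store name and query are placeholders.

```python
from openai import OpenAI

# Same base URL as in the test script above; "none" is a placeholder API key.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Every create() call mints a new store with an OpenAI-style `vs_{uuid}` ID,
# even when the name has been used before -- so keep the returned ID.
store = client.vector_stores.create(name="Support FAQ")
print(store.id)  # e.g. "vs_5f1d...", no longer the human-readable name

# All follow-up calls must reference the store by that ID, not by name.
results = client.vector_stores.search(
    vector_store_id=store.id,
    query="What is the return policy?",
    max_num_results=5,
)
for row in results.data:
    print(row.score)
```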
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import uuid
from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    SearchRankingOptions,
    VectorIO,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
        """Get the first available embedding model identifier."""
        try:
            # Get all models from the routing table
            all_models = await self.routing_table.get_all_with_type("model")

            # Filter for embedding models
            embedding_models = [
                model
                for model in all_models
                if hasattr(model, "model_type") and model.model_type == ModelType.embedding
            ]

            if embedding_models:
                dimension = embedding_models[0].metadata.get("embedding_dimension", None)
                if dimension is None:
                    raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
                return embedding_models[0].identifier, dimension
            else:
                logger.warning("No embedding models found in the routing table")
                return None
        except Exception as e:
            logger.error(f"Error getting embedding models: {e}")
            return None

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        vector_db_name: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            vector_db_name,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = None,
        provider_id: str | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

        # If no embedding model is provided, use the first available one
        if embedding_model is None:
            embedding_model_info = await self._get_first_embedding_model()
            if embedding_model_info is None:
                raise ValueError("No embedding model provided and no embedding models available in the system")
            embedding_model, embedding_dimension = embedding_model_info
            logger.info(f"No embedding model specified, using first available: {embedding_model}")

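        # Mint an OpenAI-compatible identifier (vs_<uuid>) and register it, so subsequent calls can route by this ID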
vector_db_id = f"vs_{uuid.uuid4()}"
|
|
registered_vector_db = await self.routing_table.register_vector_db(
|
|
vector_db_id=vector_db_id,
|
|
embedding_model=embedding_model,
|
|
embedding_dimension=embedding_dimension,
|
|
provider_id=provider_id,
|
|
provider_vector_db_id=vector_db_id,
|
|
vector_db_name=name,
|
|
)
|
|
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
|
name=name,
|
|
file_ids=file_ids,
|
|
expires_after=expires_after,
|
|
chunking_strategy=chunking_strategy,
|
|
metadata=metadata,
|
|
embedding_model=embedding_model,
|
|
embedding_dimension=embedding_dimension,
|
|
provider_id=registered_vector_db.provider_id,
|
|
provider_vector_db_id=registered_vector_db.provider_resource_id,
|
|
)
|
|
|
|
async def openai_list_vector_stores(
|
|
self,
|
|
limit: int | None = 20,
|
|
order: str | None = "desc",
|
|
after: str | None = None,
|
|
before: str | None = None,
|
|
) -> VectorStoreListResponse:
|
|
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
|
# Route to default provider for now - could aggregate from all providers in the future
|
|
# call retrieve on each vector dbs to get list of vector stores
|
|
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
|
all_stores = []
|
|
for vector_db in vector_dbs:
|
|
try:
|
|
vector_store = await self.routing_table.get_provider_impl(
|
|
vector_db.identifier
|
|
).openai_retrieve_vector_store(vector_db.identifier)
|
|
all_stores.append(vector_store)
|
|
except Exception as e:
|
|
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
|
continue
|
|
|
|
# Sort by created_at
|
|
reverse_order = order == "desc"
|
|
all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)
|
|
|
|
# Apply cursor-based pagination
|
|
if after:
|
|
after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
|
|
if after_index >= 0:
|
|
all_stores = all_stores[after_index + 1 :]
|
|
|
|
if before:
|
|
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
|
|
all_stores = all_stores[:before_index]
|
|
|
|
# Apply limit
|
|
limited_stores = all_stores[:limit]
|
|
|
|
# Determine pagination info
|
|
has_more = len(all_stores) > limit
|
|
first_id = limited_stores[0].id if limited_stores else None
|
|
last_id = limited_stores[-1].id if limited_stores else None
|
|
|
|
return VectorStoreListResponse(
|
|
data=limited_stores,
|
|
has_more=has_more,
|
|
first_id=first_id,
|
|
last_id=last_id,
|
|
)
|
|
|
|
async def openai_retrieve_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
) -> VectorStoreObject:
|
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_retrieve_vector_store(vector_store_id)
|
|
|
|
async def openai_update_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
name: str | None = None,
|
|
expires_after: dict[str, Any] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> VectorStoreObject:
|
|
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_update_vector_store(
|
|
vector_store_id=vector_store_id,
|
|
name=name,
|
|
expires_after=expires_after,
|
|
metadata=metadata,
|
|
)
|
|
|
|
async def openai_delete_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
) -> VectorStoreDeleteResponse:
|
|
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
result = await provider.openai_delete_vector_store(vector_store_id)
|
|
# drop from registry
|
|
await self.routing_table.unregister_vector_db(vector_store_id)
|
|
return result
|
|
|
|
async def openai_search_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
query: str | list[str],
|
|
filters: dict[str, Any] | None = None,
|
|
max_num_results: int | None = 10,
|
|
ranking_options: SearchRankingOptions | None = None,
|
|
rewrite_query: bool | None = False,
|
|
search_mode: str | None = "vector",
|
|
) -> VectorStoreSearchResponsePage:
|
|
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_search_vector_store(
|
|
vector_store_id=vector_store_id,
|
|
query=query,
|
|
filters=filters,
|
|
max_num_results=max_num_results,
|
|
ranking_options=ranking_options,
|
|
rewrite_query=rewrite_query,
|
|
search_mode=search_mode,
|
|
)
|
|
|
|
async def openai_attach_file_to_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
file_id: str,
|
|
attributes: dict[str, Any] | None = None,
|
|
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
|
) -> VectorStoreFileObject:
|
|
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_attach_file_to_vector_store(
|
|
vector_store_id=vector_store_id,
|
|
file_id=file_id,
|
|
attributes=attributes,
|
|
chunking_strategy=chunking_strategy,
|
|
)
|
|
|
|
async def openai_list_files_in_vector_store(
|
|
self,
|
|
vector_store_id: str,
|
|
limit: int | None = 20,
|
|
order: str | None = "desc",
|
|
after: str | None = None,
|
|
before: str | None = None,
|
|
filter: VectorStoreFileStatus | None = None,
|
|
) -> list[VectorStoreFileObject]:
|
|
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_list_files_in_vector_store(
|
|
vector_store_id=vector_store_id,
|
|
limit=limit,
|
|
order=order,
|
|
after=after,
|
|
before=before,
|
|
filter=filter,
|
|
)
|
|
|
|
async def openai_retrieve_vector_store_file(
|
|
self,
|
|
vector_store_id: str,
|
|
file_id: str,
|
|
) -> VectorStoreFileObject:
|
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_retrieve_vector_store_file(
|
|
vector_store_id=vector_store_id,
|
|
file_id=file_id,
|
|
)
|
|
|
|
async def openai_retrieve_vector_store_file_contents(
|
|
self,
|
|
vector_store_id: str,
|
|
file_id: str,
|
|
) -> VectorStoreFileContentsResponse:
|
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_retrieve_vector_store_file_contents(
|
|
vector_store_id=vector_store_id,
|
|
file_id=file_id,
|
|
)
|
|
|
|
async def openai_update_vector_store_file(
|
|
self,
|
|
vector_store_id: str,
|
|
file_id: str,
|
|
attributes: dict[str, Any],
|
|
) -> VectorStoreFileObject:
|
|
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_update_vector_store_file(
|
|
vector_store_id=vector_store_id,
|
|
file_id=file_id,
|
|
attributes=attributes,
|
|
)
|
|
|
|
async def openai_delete_vector_store_file(
|
|
self,
|
|
vector_store_id: str,
|
|
file_id: str,
|
|
) -> VectorStoreFileDeleteResponse:
|
|
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
|
# Route based on vector store ID
|
|
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
return await provider.openai_delete_vector_store_file(
|
|
vector_store_id=vector_store_id,
|
|
file_id=file_id,
|
|
)
|
|
|
|
async def health(self) -> dict[str, HealthResponse]:
|
|
health_statuses = {}
|
|
timeout = 1 # increasing the timeout to 1 second for health checks
|
|
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
|
try:
|
|
# check if the provider has a health method
|
|
if not hasattr(impl, "health"):
|
|
continue
|
|
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
|
health_statuses[provider_id] = health
|
|
except TimeoutError:
|
|
health_statuses[provider_id] = HealthResponse(
|
|
status=HealthStatus.ERROR,
|
|
message=f"Health check timed out after {timeout} seconds",
|
|
)
|
|
except NotImplementedError:
|
|
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
|
except Exception as e:
|
|
health_statuses[provider_id] = HealthResponse(
|
|
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
|
)
|
|
return health_statuses
|