# What does this PR do?

This is an initial working prototype of wiring up the `file_search` builtin tool for the Responses API to our existing rag knowledge search tool. This is me seeing what I could pull together on top of the bits we already have merged. This may not be the ideal way to implement this, and things like how I shuffle the vector store ids from the original Responses API tool request to the actual tool execution feel a bit hacky (grep for `tool_kwargs["vector_db_ids"]` in `_execute_tool_call` to see what I mean; a rough sketch of that hand-off follows this description).

## Test Plan

I stubbed in some new tests to exercise this using text and PDF documents. Note that this currently lives under tests/verifications only because it sometimes flakes on tool calling with the small Llama-3.2-3B model we run in CI (and that I use as an example below). We'd want to make the test more robust in some way if we moved it over to tests/integration and ran it in CI.

### OpenAI SaaS (to verify test correctness)

```
pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=https://api.openai.com/v1 \
  --model=gpt-4o
```

### Fireworks with faiss vector store

```
llama stack run llama_stack/templates/fireworks/run.yaml

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
```

### Ollama with faiss vector store

This sometimes flakes on Ollama because the quantized small model doesn't always choose to call the tool to answer the user's question. But it often works.

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
  llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.2-3B-Instruct
```

### OpenAI provider with sqlite-vec vector store

```
llama stack run ./llama_stack/templates/starter/run.yaml --image-type venv

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=openai/gpt-4o-mini
```

### Ensure existing vector store integration tests still pass

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
  llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

LLAMA_STACK_CONFIG=http://localhost:8321 \
  pytest -sv tests/integration/vector_io \
  --text-model "meta-llama/Llama-3.2-3B-Instruct" \
  --embedding-model=all-MiniLM-L6-v2
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
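Since the description points at `tool_kwargs["vector_db_ids"]` without showing it, here is a rough, hypothetical sketch of that hand-off. Only the kwarg name `vector_db_ids` and the method name `_execute_tool_call` come from the description above; the function, its signature, and the `knowledge_search` tool name are illustrative stand-ins, not the merged code:

```python
from typing import Any


# Hypothetical sketch, not the merged implementation: inside something like
# _execute_tool_call, the vector store ids that arrived on the original
# Responses API file_search tool definition get injected into the kwargs of
# the underlying rag knowledge search tool call.
def build_tool_kwargs(
    tool_name: str,
    model_arguments: dict[str, Any],
    file_search_vector_store_ids: list[str] | None,
) -> dict[str, Any]:
    tool_kwargs = dict(model_arguments)
    if tool_name == "knowledge_search" and file_search_vector_store_ids:
        # The "hacky" part: these ids come from the request's tool
        # definition, not from the model's tool call, so they have to be
        # smuggled in at execution time.
        tool_kwargs["vector_db_ids"] = file_search_vector_store_ids
    return tool_kwargs
```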
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
        """Return the identifier and embedding dimension of the first available embedding model, or None."""
        try:
            # Get all models from the routing table
            all_models = await self.routing_table.get_all_with_type("model")

            # Filter for embedding models
            embedding_models = [
                model
                for model in all_models
                if hasattr(model, "model_type") and model.model_type == ModelType.embedding
            ]

            if embedding_models:
                dimension = embedding_models[0].metadata.get("embedding_dimension", None)
                if dimension is None:
                    raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
                return embedding_models[0].identifier, dimension
            else:
                logger.warning("No embedding models found in the routing table")
                return None
        except Exception as e:
            logger.error(f"Error getting embedding models: {e}")
            return None

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
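        # NOTE: 384 matches small sentence-transformer embedding models such
        # as all-MiniLM-L6-v2; override this to match the embedding model in use.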
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = None,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

        # If no embedding model is provided, use the first available one
        if embedding_model is None:
            embedding_model_info = await self._get_first_embedding_model()
            if embedding_model_info is None:
                raise ValueError("No embedding model provided and no embedding models available in the system")
            embedding_model, embedding_dimension = embedding_model_info
            logger.info(f"No embedding model specified, using first available: {embedding_model}")

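        # The caller-supplied store name doubles as the vector db identifier
        # used for registration and routing below.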
        vector_db_id = name
        registered_vector_db = await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

        return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
            vector_db_id,
            file_ids=file_ids,
            expires_after=expires_after,
            chunking_strategy=chunking_strategy,
            metadata=metadata,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension,
            provider_id=registered_vector_db.provider_id,
            provider_vector_db_id=registered_vector_db.provider_resource_id,
        )

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
        # Aggregate stores across all providers: retrieve each registered
        # vector db from its own provider, then paginate the combined list.
        vector_dbs = await self.routing_table.get_all_with_type("vector_db")
        all_stores = []
        for vector_db in vector_dbs:
            vector_store = await self.routing_table.get_provider_impl(
                vector_db.identifier
            ).openai_retrieve_vector_store(vector_db.identifier)
            all_stores.append(vector_store)

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = limited_stores[0].id if limited_stores else None
        last_id = limited_stores[-1].id if limited_stores else None

        return VectorStoreListResponse(
            data=limited_stores,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        # drop from registry
        await self.routing_table.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )
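
For orientation, a hypothetical usage sketch of the router above. In practice the stack's distribution code constructs the routing table and router itself, so `my_routing_table` below is a stand-in for that wiring, not a real constructor call:

```python
import asyncio


async def demo(my_routing_table) -> None:
    # my_routing_table is assumed to be a configured RoutingTable with at
    # least one registered embedding model and vector_io provider.
    router = VectorIORouter(routing_table=my_routing_table)
    await router.initialize()

    # With no embedding_model given, the router falls back to the first
    # embedding model in the routing table (see _get_first_embedding_model).
    store = await router.openai_create_vector_store(name="my-docs")

    # Requests are routed to the provider that owns this store's identifier.
    page = await router.openai_search_vector_store(
        vector_store_id=store.id,
        query="What does the router do?",
        max_num_results=5,
    )
    print(page)


# asyncio.run(demo(my_routing_table))  # requires a configured RoutingTable
```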