Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)

refactor to only configure the model at build time

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

This commit is contained in:
parent 2cc7943fd6
commit d887f1f8bb

31 changed files with 280 additions and 315 deletions
@@ -18,6 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,
@@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
         description="Default LLM model for query expansion/rewriting in vector search.",
     )
     query_expansion_prompt: str = Field(
-        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        default=DEFAULT_QUERY_EXPANSION_PROMPT,
         description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
     )
+    query_expansion_max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens for query expansion responses.",
+    )
+    query_expansion_temperature: float = Field(
+        default=0.3,
+        description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
+    )


 class SafetyConfig(BaseModel):
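Note: a minimal sketch of the expanded config after this hunk, assuming the other VectorStoresConfig fields are optional as the validation code below implies; the provider and model identifiers are placeholders, not taken from any shipped distribution.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

# Only the model needs to be configured; prompt, max_tokens, and temperature
# fall back to the defaults declared on the Field(...) definitions above.
cfg = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
)

assert cfg.query_expansion_max_tokens == 100
assert cfg.query_expansion_temperature == 0.3
assert "{query}" in cfg.query_expansion_prompt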
@@ -374,13 +374,6 @@ async def instantiate_provider(
         method = "get_adapter_impl"
         args = [config, deps]
-
-        # Add vector_stores_config for vector_io providers
-        if (
-            "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
-            and provider_spec.api == Api.vector_io
-        ):
-            args.append(run_config.vector_stores)
-
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"

@@ -401,11 +394,6 @@ async def instantiate_provider(
         args.append(policy)
     if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
         args.append(run_config.telemetry.enabled)
-    if (
-        "vector_stores_config" in inspect.signature(getattr(module, method)).parameters
-        and provider_spec.api == Api.vector_io
-    ):
-        args.append(run_config.vector_stores)

     fn = getattr(module, method)
     impl = await fn(*args)
@@ -99,19 +99,6 @@ class VectorIORouter(VectorIO):
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
-
-        # Ensure params dict exists and add vector_stores_config for query rewriting
-        if params is None:
-            params = {}
-
-        logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
-        if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
-            logger.debug(
-                f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-            )
-
-        params["vector_stores_config"] = self.vector_stores_config
-
         return await provider.query_chunks(vector_store_id, query, params)

     # OpenAI Vector Stores API endpoints
@@ -144,16 +144,17 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
     if vector_stores_config is None:
         return

+    # Validate default embedding model
     default_embedding_model = vector_stores_config.default_embedding_model
-    if default_embedding_model is None:
-        return
-
-    provider_id = default_embedding_model.provider_id
-    model_id = default_embedding_model.model_id
-    default_model_id = f"{provider_id}/{model_id}"
-
-    if Api.models not in impls:
-        raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
-
-    models_impl = impls[Api.models]
-    response = await models_impl.list_models()
+    if default_embedding_model is not None:
+        provider_id = default_embedding_model.provider_id
+        model_id = default_embedding_model.model_id
+        default_model_id = f"{provider_id}/{model_id}"
+
+        if Api.models not in impls:
+            raise ValueError(
+                f"Models API is not available but vector_stores config requires model '{default_model_id}'"
+            )
+
+        models_impl = impls[Api.models]
+        response = await models_impl.list_models()

@@ -161,7 +162,9 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig

-    default_model = models_list.get(default_model_id)
-    if default_model is None:
-        raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
+        default_model = models_list.get(default_model_id)
+        if default_model is None:
+            raise ValueError(
+                f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
+            )

-    embedding_dimension = default_model.metadata.get("embedding_dimension")
-    if embedding_dimension is None:
+        embedding_dimension = default_model.metadata.get("embedding_dimension")
+        if embedding_dimension is None:

@@ -174,6 +177,30 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig

-    logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+        logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
+
+    # Validate default query expansion model
+    default_query_expansion_model = vector_stores_config.default_query_expansion_model
+    if default_query_expansion_model is not None:
+        provider_id = default_query_expansion_model.provider_id
+        model_id = default_query_expansion_model.model_id
+        query_model_id = f"{provider_id}/{model_id}"
+
+        if Api.models not in impls:
+            raise ValueError(
+                f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
+            )
+
+        models_impl = impls[Api.models]
+        response = await models_impl.list_models()
+        llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
+
+        query_expansion_model = llm_models_list.get(query_model_id)
+        if query_expansion_model is None:
+            raise ValueError(
+                f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
+            )
+
+        logger.debug(f"Validated default query expansion model: {query_model_id}")
+

 async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
     if safety_config is None or safety_config.default_shield_id is None:
@@ -437,6 +464,12 @@ class Stack:
             await refresh_registry_once(impls)
         await validate_vector_stores_config(self.run_config.vector_stores, impls)
         await validate_safety_config(self.run_config.safety, impls)
+
+        # Set global query expansion configuration from stack config
+        from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+
+        set_default_query_expansion_config(self.run_config.vector_stores)
+
         self.impls = impls

     def create_registry_refresh_task(self):
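Note: a sketch (not part of the diff) contrasting the two wiring styles this hunk introduces; FakeRunConfig and wire_query_expansion are stand-ins used only to keep the example self-contained.

from dataclasses import dataclass

from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.providers.utils.memory.query_expansion_config import (
    set_default_query_expansion_config,
)


@dataclass
class FakeRunConfig:
    vector_stores: VectorStoresConfig | None = None


def wire_query_expansion(run_config: FakeRunConfig) -> None:
    # Before this commit: run_config.vector_stores was appended to every
    # vector_io provider's constructor args. After: one global call at startup.
    set_default_query_expansion_config(run_config.vector_stores)


wire_query_expansion(FakeRunConfig(vector_stores=None))  # safe when unconfigured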
@@ -296,5 +296,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

@@ -287,5 +287,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

@@ -299,5 +299,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

@@ -290,5 +290,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

@@ -296,5 +296,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard

@@ -287,5 +287,7 @@ vector_stores:


    Improved query:'
+  query_expansion_max_tokens: 100
+  query_expansion_temperature: 0.3
 safety:
   default_shield_id: llama-guard
@@ -6,19 +6,16 @@

 from typing import Any

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(
-    config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -14,7 +14,6 @@ import faiss  # type: ignore[import-untyped]
 import numpy as np
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin

@@ -190,12 +189,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         config: FaissVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}

     async def initialize(self) -> None:

@@ -211,7 +208,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
                 vector_store,
                 await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
                 self.inference_api,
-                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index

@@ -250,7 +246,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )

     async def list_vector_stores(self) -> list[VectorStore]:

@@ -284,7 +279,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
             vector_store=vector_store,
             index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index
@@ -6,18 +6,15 @@

 from typing import Any

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(
-    config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -14,7 +14,6 @@ import numpy as np
 import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin

@@ -391,12 +390,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         config,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorStoreWithIndex] = {}
         self.vector_store_table = None

@@ -411,9 +408,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
             index = await SQLiteVecIndex.create(
                 vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
             )
-            self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store, index, self.inference_api, self.vector_stores_config
-            )
+            self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()

@@ -437,9 +432,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
             index = await SQLiteVecIndex.create(
                 vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
             )
-            self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store, index, self.inference_api, self.vector_stores_config
-            )
+            self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:

@@ -464,7 +457,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
                     kvstore=self.kvstore,
                 ),
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store_id] = index
             return index
@@ -4,17 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import ChromaVectorIOConfig


-async def get_adapter_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter

-    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -11,7 +11,6 @@ from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig

@@ -126,13 +125,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.client = None
         self.cache = {}
         self.vector_store_table = None

@@ -165,7 +162,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             )
         )
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
+            vector_store, ChromaIndex(self.client, collection), self.inference_api
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:

@@ -210,9 +207,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         collection = await maybe_await(self.client.get_collection(vector_store_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
-        index = VectorStoreWithIndex(
-            vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
-        )
+        index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_store_id] = index
         return index
@@ -4,18 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import MilvusVectorIOConfig


-async def get_adapter_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .milvus import MilvusVectorIOAdapter

     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -11,7 +11,6 @@ from typing import Any
 from numpy.typing import NDArray
 from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig

@@ -273,14 +272,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self.metadata_collection_name = "openai_vector_stores_metadata"

@@ -301,7 +298,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
                     kvstore=self.kvstore,
                 ),
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         if isinstance(self.config, RemoteMilvusVectorIOConfig):

@@ -329,7 +325,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )

         self.cache[vector_store.identifier] = index

@@ -352,7 +347,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import PGVectorVectorIOConfig

@@ -13,10 +12,9 @@ from .config import PGVectorVectorIOConfig
 async def get_adapter_impl(
     config: PGVectorVectorIOConfig,
     deps: dict[Api, ProviderSpec],
-    vector_stores_config: VectorStoresConfig | None = None,
 ):
     from .pgvector import PGVectorVectorIOAdapter

-    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -13,7 +13,6 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

@@ -335,12 +334,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         config: PGVectorVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_store_table = None

@@ -396,7 +393,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
                 vector_store,
                 index=pgvector_index,
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index

@@ -428,7 +424,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             vector_store,
             index=pgvector_index,
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store.identifier] = index

@@ -469,9 +464,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         vector_store = VectorStore.model_validate_json(vector_store_data)
         index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
         await index.initialize()
-        self.cache[vector_store_id] = VectorStoreWithIndex(
-            vector_store, index, self.inference_api, self.vector_stores_config
-        )
+        self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
         return self.cache[vector_store_id]

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -4,17 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
-):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorIOAdapter

-    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -13,7 +13,6 @@ from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig

@@ -153,14 +152,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.vector_store_table = None
         self._qdrant_lock = asyncio.Lock()

@@ -179,7 +176,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
                 vector_store,
                 QdrantIndex(self.client, vector_store.identifier),
                 self.inference_api,
-                self.vector_stores_config,
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()

@@ -199,7 +195,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=QdrantIndex(self.client, vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )

         self.cache[vector_store.identifier] = index

@@ -231,7 +226,6 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack_api import Api, ProviderSpec

 from .config import WeaviateVectorIOConfig

@@ -13,10 +12,9 @@ from .config import WeaviateVectorIOConfig
 async def get_adapter_impl(
     config: WeaviateVectorIOConfig,
     deps: dict[Api, ProviderSpec],
-    vector_stores_config: VectorStoresConfig | None = None,
 ):
     from .weaviate import WeaviateVectorIOAdapter

-    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
@@ -12,7 +12,6 @@ from numpy.typing import NDArray
 from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter, HybridFusion

-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger

@@ -268,12 +267,10 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         config: WeaviateVectorIOConfig,
         inference_api: Inference,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_store_table = None

@@ -319,7 +316,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
                 vector_store=vector_store,
                 index=idx,
                 inference_api=self.inference_api,
-                vector_stores_config=self.vector_stores_config,
             )

         # Load OpenAI vector stores metadata into cache

@@ -348,7 +344,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             vector_store,
             WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             self.inference_api,
-            self.vector_stores_config,
         )

     async def unregister_vector_store(self, vector_store_id: str) -> None:

@@ -383,7 +378,6 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             vector_store=vector_store,
             index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
             inference_api=self.inference_api,
-            vector_stores_config=self.vector_stores_config,
         )
         self.cache[vector_store_id] = index
         return index
@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"]
src/llama_stack/providers/utils/memory/constants.py (new file, 8 additions)

@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Default prompt template for query expansion in vector search
+DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
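Note: a minimal sketch of how this template is consumed; the same .format(query=...) call appears in VectorStoreWithIndex._rewrite_query_for_file_search later in this commit. The query text is made up for illustration.

from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT

# Fill the {query} placeholder with the user's original search text.
prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="how do I rotate API keys?")

assert prompt.startswith("Expand this query")
assert "how do I rotate API keys?" in prompt
assert prompt.endswith("Improved query:")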
@@ -379,11 +379,6 @@ class OpenAIVectorStoreMixin(ABC):
             f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
         )

-        # Extract query expansion model from extra_body if provided
-        query_expansion_model = extra_body.get("query_expansion_model")
-        if query_expansion_model:
-            logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
-
         # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
         provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
         # Derive the canonical vector_store_id (allow override, else generate)

@@ -407,7 +402,6 @@ class OpenAIVectorStoreMixin(ABC):
             provider_id=provider_id,
             provider_resource_id=vector_store_id,
             vector_store_name=params.name,
-            query_expansion_model=query_expansion_model,
         )
         await self.register_vector_store(vector_store)

@@ -621,9 +615,6 @@ class OpenAIVectorStoreMixin(ABC):
             "rewrite_query": rewrite_query,
         }

-        # Add vector_stores_config if available (for query rewriting)
-        if hasattr(self, "vector_stores_config"):
-            params["vector_stores_config"] = self.vector_stores_config
         # TODO: Add support for ranking_options.ranker

         response = await self.query_chunks(
src/llama_stack/providers/utils/memory/query_expansion_config.py (new file, 37 additions)

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+
+# Global configuration for query expansion - set during stack startup
+_DEFAULT_QUERY_EXPANSION_MODEL: QualifiedModel | None = None
+_DEFAULT_QUERY_EXPANSION_MAX_TOKENS: int = 100
+_DEFAULT_QUERY_EXPANSION_TEMPERATURE: float = 0.3
+_QUERY_EXPANSION_PROMPT_OVERRIDE: str | None = None
+
+
+def set_default_query_expansion_config(vector_stores_config: VectorStoresConfig | None):
+    """Set the global default query expansion configuration from stack config."""
+    global \
+        _DEFAULT_QUERY_EXPANSION_MODEL, \
+        _QUERY_EXPANSION_PROMPT_OVERRIDE, \
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS, \
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE
+    if vector_stores_config:
+        _DEFAULT_QUERY_EXPANSION_MODEL = vector_stores_config.default_query_expansion_model
+        # Only set override if user provided a custom prompt different from default
+        if vector_stores_config.query_expansion_prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+            _QUERY_EXPANSION_PROMPT_OVERRIDE = vector_stores_config.query_expansion_prompt
+        else:
+            _QUERY_EXPANSION_PROMPT_OVERRIDE = None
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = vector_stores_config.query_expansion_max_tokens
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = vector_stores_config.query_expansion_temperature
+    else:
+        _DEFAULT_QUERY_EXPANSION_MODEL = None
+        _QUERY_EXPANSION_PROMPT_OVERRIDE = None
+        _DEFAULT_QUERY_EXPANSION_MAX_TOKENS = 100
+        _DEFAULT_QUERY_EXPANSION_TEMPERATURE = 0.3
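Note: a short usage sketch of this module, assuming the non-query-expansion fields of VectorStoresConfig are optional with defaults; the provider and model identifiers are placeholders. The stack sets the module-level state once at startup, and the memory utilities read it when a query needs rewriting.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory import query_expansion_config
from llama_stack.providers.utils.memory.query_expansion_config import (
    set_default_query_expansion_config,
)

# Populate the globals once, as Stack does during initialization.
set_default_query_expansion_config(
    VectorStoresConfig(
        default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
    )
)

# The prompt was left at its default, so no override is recorded and
# _rewrite_query_for_file_search falls back to DEFAULT_QUERY_EXPANSION_PROMPT.
assert query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE is None
assert query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL is not None

# Passing None resets everything to the built-in defaults.
set_default_query_expansion_config(None)
assert query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL is None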
@ -17,7 +17,6 @@ import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
@ -30,19 +29,18 @@ from llama_stack_api import (
|
||||||
Chunk,
|
Chunk,
|
||||||
ChunkMetadata,
|
ChunkMetadata,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
OpenAIChatCompletionRequestWithExtraBody,
|
||||||
OpenAIEmbeddingsRequestWithExtraBody,
|
OpenAIEmbeddingsRequestWithExtraBody,
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
VectorStore,
|
VectorStore,
|
||||||
)
|
)
|
||||||
from llama_stack_api.inference import (
|
|
||||||
OpenAIChatCompletionRequestWithExtraBody,
|
|
||||||
OpenAIUserMessageParam,
|
|
||||||
)
|
|
||||||
from llama_stack_api.models import ModelType
|
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="providers::utils")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.memory import query_expansion_config
|
||||||
|
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
|
||||||
|
|
||||||
|
|
||||||
class ChunkForDeletion(BaseModel):
|
class ChunkForDeletion(BaseModel):
|
||||||
"""Information needed to delete a chunk from a vector store.
|
"""Information needed to delete a chunk from a vector store.
|
||||||
|
|
@ -268,7 +266,6 @@ class VectorStoreWithIndex:
|
||||||
vector_store: VectorStore
|
vector_store: VectorStore
|
||||||
index: EmbeddingIndex
|
index: EmbeddingIndex
|
||||||
inference_api: Api.inference
|
inference_api: Api.inference
|
||||||
vector_stores_config: VectorStoresConfig | None = None
|
|
||||||
|
|
||||||
async def insert_chunks(
|
async def insert_chunks(
|
||||||
self,
|
self,
|
||||||
|
|
@ -296,6 +293,39 @@ class VectorStoreWithIndex:
|
||||||
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||||
await self.index.add_chunks(chunks, embeddings)
|
await self.index.add_chunks(chunks, embeddings)
|
||||||
|
|
||||||
|
async def _rewrite_query_for_file_search(self, query: str) -> str:
|
||||||
|
"""Rewrite a search query using the globally configured LLM model for better retrieval results."""
|
||||||
|
if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
|
||||||
|
log.debug("No default query expansion model configured, using original query")
|
||||||
|
return query
|
||||||
|
|
||||||
|
model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
|
||||||
|
|
||||||
|
# Use custom prompt from config if provided, otherwise use built-in default
|
||||||
|
# Users only need to configure the model - prompt is automatic with optional override
|
||||||
|
if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
|
||||||
|
# Custom prompt from config - format if it contains {query} placeholder
|
||||||
|
prompt = (
|
||||||
|
query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
|
||||||
|
if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
|
||||||
|
else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use built-in default prompt and format with query
|
||||||
|
prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query)
|
||||||
|
|
||||||
|
request = OpenAIChatCompletionRequestWithExtraBody(
|
||||||
|
model=model_id,
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
|
||||||
|
temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.inference_api.openai_chat_completion(request)
|
||||||
|
rewritten_query = response.choices[0].message.content.strip()
|
||||||
|
log.debug(f"Query rewritten: '{query}' → '{rewritten_query}'")
|
||||||
|
return rewritten_query
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
self,
|
self,
|
||||||
query: InterleavedContent,
|
query: InterleavedContent,
|
||||||
|
|
@@ -304,10 +334,6 @@ class VectorStoreWithIndex:
         if params is None:
             params = {}
 
-        # Extract configuration if provided by router
-        if "vector_stores_config" in params:
-            self.vector_stores_config = params["vector_stores_config"]
-
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -331,18 +357,9 @@ class VectorStoreWithIndex:
 
         query_string = interleaved_content_as_str(query)
 
-        # Apply query rewriting if enabled
+        # Apply query rewriting if enabled and model is configured
         if params.get("rewrite_query", False):
-            if self.vector_stores_config:
-                log.debug(f"VectorStoreWithIndex received config: {self.vector_stores_config}")
-                if hasattr(self.vector_stores_config, "default_query_expansion_model"):
-                    log.debug(
-                        f"Config has default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-                    )
-            else:
-                log.debug("No vector_stores_config found - cannot perform query rewriting")
-
-            query_string = await self._rewrite_query_for_search(query_string)
+            query_string = await self._rewrite_query_for_file_search(query_string)
 
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
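For illustration, a hedged caller-side sketch of the simplified flow (the `store_with_index` instance and the query text are hypothetical, and the call is assumed to run inside an async context); query rewriting is opt-in per call and is silently skipped when no model was configured at build time:

# Hedged usage sketch: rewrite_query is only honored when a default query
# expansion model was configured at build time; otherwise the original
# query string is used unchanged.
params = {"rewrite_query": True, "max_chunks": 5}
response = await store_with_index.query_chunks("how do I rotate api keys", params)
for chunk, score in zip(response.chunks, response.scores):
    print(score, chunk)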
@@ -359,88 +376,3 @@ class VectorStoreWithIndex:
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
-
-    async def _rewrite_query_for_search(self, query: str) -> str:
-        """Rewrite the user query to improve vector search performance.
-
-        :param query: The original user query
-        :returns: The rewritten query optimized for vector search
-        """
-        expansion_model = None
-
-        # Check for per-store query expansion model first
-        if self.vector_store.query_expansion_model:
-            # Parse the model string into provider_id and model_id
-            model_parts = self.vector_store.query_expansion_model.split("/", 1)
-            if len(model_parts) == 2:
-                expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
-                log.debug(f"Using per-store query expansion model: {expansion_model}")
-            else:
-                log.warning(
-                    f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
-                )
-
-        # Fall back to global default if no per-store model
-        if not expansion_model:
-            if not self.vector_stores_config:
-                raise ValueError(
-                    f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
-                )
-            if not self.vector_stores_config.default_query_expansion_model:
-                raise ValueError(
-                    f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-                )
-            expansion_model = self.vector_stores_config.default_query_expansion_model
-            log.debug(f"Using global default query expansion model: {expansion_model}")
-
-        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
-
-        # Validate that the model is available and is an LLM
-        try:
-            models_response = await self.inference_api.routing_table.list_models()
-        except Exception as e:
-            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
-
-        model_found = False
-        for model in models_response.data:
-            if model.identifier == chat_model:
-                if model.model_type != ModelType.llm:
-                    raise ValueError(
-                        f"Configured query expansion model '{chat_model}' is not an LLM model "
-                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
-                    )
-                model_found = True
-                break
-
-        if not model_found:
-            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
-            raise ValueError(
-                f"Configured query expansion model '{chat_model}' is not available. "
-                f"Available LLM models: {available_llm_models}"
-            )
-
-        # Use the configured prompt (has a default value)
-        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
-
-        chat_request = OpenAIChatCompletionRequestWithExtraBody(
-            model=chat_model,
-            messages=[
-                OpenAIUserMessageParam(
-                    role="user",
-                    content=rewrite_prompt,
-                )
-            ],
-            max_tokens=100,
-        )
-
-        try:
-            response = await self.inference_api.openai_chat_completion(chat_request)
-        except Exception as e:
-            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
-
-        if response.choices and len(response.choices) > 0:
-            rewritten_query = response.choices[0].message.content.strip()
-            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
-            return rewritten_query
-        else:
-            raise RuntimeError("No response received from LLM model for query rewriting")
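The deleted `_rewrite_query_for_search` resolved and validated the expansion model on every request; after this refactor the model is fixed once at build time. A hedged sketch of that wiring, using only names that appear elsewhere in this change (the provider id and model id values are placeholders, and the exact call site inside stack construction is an assumption):

# Hypothetical build-time wiring; the call site is an assumption.
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config

vector_stores_config = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
    query_expansion_max_tokens=100,
    query_expansion_temperature=0.3,
)
# Register once while the stack is being constructed; VectorStoreWithIndex
# then reads the module-level globals at query time.
set_default_query_expansion_config(vector_stores_config)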
@@ -25,7 +25,6 @@ class VectorStore(Resource):
     embedding_model: str
     embedding_dimension: int
     vector_store_name: str | None = None
-    query_expansion_model: str | None = None
 
     @property
     def vector_store_id(self) -> str:
@@ -1233,94 +1233,122 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
 
 
 async def test_query_expansion_functionality(vector_io_adapter):
-    """Test query expansion with per-store models, global defaults, and error validation."""
+    """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock
 
-    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
-    from llama_stack_api.models import Model, ModelType
+    from llama_stack_api import QueryChunksResponse
 
-    vector_io_adapter.register_vector_store = AsyncMock()
-    vector_io_adapter.__provider_id__ = "test_provider"
-
-    # Test 1: Per-store model usage
-    params = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store",
-        metadata={},
-        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
-    )
-    await vector_io_adapter.openai_create_vector_store(params)
-    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args.query_expansion_model == "test/llama-model"
-
-    # Test 2: Global default fallback
-    vector_io_adapter.register_vector_store.reset_mock()
-    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
-    )
-    await vector_io_adapter.openai_create_vector_store(params_no_model)
-    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args2.query_expansion_model is None
-
-    # Test query rewriting scenarios
-    mock_inference_api = MagicMock()
-
-    # Per-store model scenario
+    # Mock a simple vector store and index
     mock_vector_store = MagicMock()
-    mock_vector_store.query_expansion_model = "test/llama-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
-        )
-    )
-    mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
-    )
+    mock_vector_store.embedding_model = "test/embedding"
+    mock_inference_api = MagicMock()
+    mock_index = MagicMock()
 
+    # Create VectorStoreWithIndex with simplified constructor
     vector_store_with_index = VectorStoreWithIndex(
         vector_store=mock_vector_store,
-        index=MagicMock(),
+        index=mock_index,
         inference_api=mock_inference_api,
-        vector_stores_config=VectorStoresConfig(
-            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
-        ),
     )
 
-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
-    assert result == "per-store expanded"
+    # Mock the query_vector method to return a simple response
+    mock_response = QueryChunksResponse(chunks=[], scores=[])
+    mock_index.query_vector = AsyncMock(return_value=mock_response)
 
-    # Global default fallback scenario
-    mock_inference_api.reset_mock()
-    mock_vector_store.query_expansion_model = None
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
-        )
+    # Mock embeddings generation
+    mock_inference_api.openai_embeddings = AsyncMock(
+        return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
     )
 
+    # Test 1: Query expansion with default prompt (no custom prompt configured)
+    mock_vector_stores_config = MagicMock()
+    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
+    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
+    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
+    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+
+    # Set global config
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))])
     )
 
-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
-    assert result == "global expanded"
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result = await vector_store_with_index.query_chunks("test query", params)
 
-    # Test 3: Error cases
-    # Model not found
-    mock_vector_store.query_expansion_model = "missing/model"
-    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
+    # Verify chat completion was called for query rewriting
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    assert chat_call_args.model == "test/llama"
 
-    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify default prompt is used (contains our built-in prompt text)
+    prompt_text = chat_call_args.messages[0].content
+    expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
+    assert prompt_text == expected_prompt
 
-    # Non-LLM model
-    mock_vector_store.query_expansion_model = "test/embedding-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
-        )
-    )
+    # Verify default inference parameters are used
+    assert chat_call_args.max_tokens == 100  # Default value
+    assert chat_call_args.temperature == 0.3  # Default value
 
-    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify the rest of the flow proceeded normally
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result == mock_response
+
+    # Test 1b: Query expansion with custom prompt override and inference parameters
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
+    mock_vector_stores_config.query_expansion_max_tokens = 150
+    mock_vector_stores_config.query_expansion_temperature = 0.7
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    result = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify custom prompt and parameters are used
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    prompt_text = chat_call_args.messages[0].content
+    assert prompt_text == "Custom prompt for rewriting: test query"
+    assert "Expand this query with relevant synonyms" not in prompt_text  # Default not used
+
+    # Verify custom inference parameters
+    assert chat_call_args.max_tokens == 150
+    assert chat_call_args.temperature == 0.7
+
+    # Test 2: No query expansion when no global model is configured
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    # Clear global config
+    set_default_query_expansion_config(None)
+
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result2 = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify chat completion was NOT called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+    # But normal flow should still work
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result2 == mock_response
+
+    # Test 3: Normal behavior without rewrite_query parameter
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    params_no_rewrite = {"max_chunks": 5}
+    result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite)
+
+    # Neither chat completion nor query rewriting should be called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result3 == mock_response
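Since the test mutates module-level state through `set_default_query_expansion_config`, a small autouse fixture is one way to keep that state from leaking between tests. This is a hedged suggestion, not part of the commit:

# Hedged sketch of a cleanup fixture: set_default_query_expansion_config mutates
# module-level globals, so resetting them after each test prevents later tests
# from inheriting a query expansion model.
import pytest

from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config


@pytest.fixture(autouse=True)
def _reset_query_expansion_config():
    yield
    set_default_query_expansion_config(None)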