Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

undoing formatting and updating missed expansion parameters
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 88ce118ba3
commit 2ebc56c3d9
13 changed files with 23 additions and 61 deletions

@@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import (
     StorageConfig,
 )
 from llama_stack.log import LoggingConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 from llama_stack_api import (
     Api,
     Benchmark,

@@ -374,7 +374,7 @@ class RewriteQueryParams(BaseModel):
         description="LLM model for query rewriting/expansion in vector search.",
     )
     prompt: str = Field(
-        default=DEFAULT_QUERY_EXPANSION_PROMPT,
+        default=DEFAULT_QUERY_REWRITE_PROMPT,
         description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
     )
     max_tokens: int = Field(

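For reference, a minimal sketch of what the renamed default means for the pydantic model above; it assumes pydantic v2 (`model_fields`) and is illustrative only, not part of the change:

```python
from llama_stack.core.datatypes import RewriteQueryParams
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

# The `prompt` field now falls back to the renamed constant when not set explicitly.
assert RewriteQueryParams.model_fields["prompt"].default == DEFAULT_QUERY_REWRITE_PROMPT
```
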
@@ -184,12 +184,7 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
-    def __init__(
-        self,
-        config: FaissVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api

@@ -385,12 +385,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
     and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
     """
 
-    def __init__(
-        self,
-        config,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api

@@ -9,10 +9,7 @@ from llama_stack_api import Api, ProviderSpec
 from .config import PGVectorVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: PGVectorVectorIOConfig,
-    deps: dict[Api, ProviderSpec],
-):
+async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter
 
     impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))

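For context, a sketch of how a factory like this typically completes. Only the flattened signature, the inner import, and the constructor call come from the hunk above; the initialize() call and the return are assumptions:

```python
from llama_stack_api import Api, ProviderSpec

from .config import PGVectorVectorIOConfig


async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .pgvector import PGVectorVectorIOAdapter

    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await impl.initialize()  # assumed lifecycle call, not shown in this diff
    return impl
```
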
@@ -330,10 +330,7 @@ class PGVectorIndex(EmbeddingIndex):
 
 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
     def __init__(
-        self,
-        config: PGVectorVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None = None,
+        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config

@@ -389,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
                 kvstore=self.kvstore,
             )
             await pgvector_index.initialize()
-            index = VectorStoreWithIndex(
-                vector_store,
-                index=pgvector_index,
-                inference_api=self.inference_api,
-            )
+            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
             self.cache[vector_store.identifier] = index
 
     async def shutdown(self) -> None:

@@ -420,11 +413,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorStoreWithIndex(
-            vector_store,
-            index=pgvector_index,
-            inference_api=self.inference_api,
-        )
+        index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
         self.cache[vector_store.identifier] = index
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:

@@ -173,9 +173,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         for vector_store_data in stored_vector_stores:
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = VectorStoreWithIndex(
-                vector_store,
-                QdrantIndex(self.client, vector_store.identifier),
-                self.inference_api,
+                vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
             )
             self.cache[vector_store.identifier] = index
         self.openai_vector_stores = await self._load_openai_vector_stores()

@@ -9,10 +9,7 @@ from llama_stack_api import Api, ProviderSpec
 from .config import WeaviateVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: WeaviateVectorIOConfig,
-    deps: dict[Api, ProviderSpec],
-):
+async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter
 
     impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))

@@ -262,12 +262,7 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
-    def __init__(
-        self,
-        config: WeaviateVectorIOConfig,
-        inference_api: Inference,
-        files_api: Files | None,
-    ) -> None:
+    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api

@@ -313,9 +308,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
             client = self._get_client()
             idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
             self.cache[vector_store.identifier] = VectorStoreWithIndex(
-                vector_store=vector_store,
-                index=idx,
-                inference_api=self.inference_api,
+                vector_store=vector_store, index=idx, inference_api=self.inference_api
             )
 
         # Load OpenAI vector stores metadata into cache

@@ -341,9 +334,7 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         )
 
         self.cache[vector_store.identifier] = VectorStoreWithIndex(
-            vector_store,
-            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
-            self.inference_api,
+            vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
        )
 
     async def unregister_vector_store(self, vector_store_id: str) -> None:

@@ -4,6 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from .constants import DEFAULT_QUERY_REWRITE_PROMPT
 
-__all__ = ["DEFAULT_QUERY_EXPANSION_PROMPT"]
+__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]

@@ -4,5 +4,5 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Default prompt template for query expansion in vector search
-DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
+# Default prompt template for query rewriting in vector search
+DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"

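The constant keeps its value; only the name changes. A quick sketch of how the {query} placeholder is filled (the example query is made up):

```python
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

# The template has a single {query} placeholder; .format() yields the text sent to the model.
prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="fastest vector index for small datasets")
print(prompt)
```
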
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 
 # Global configuration for query rewriting - set during stack startup
 _DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None

@@ -25,7 +25,7 @@ def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig |
     params = vector_stores_config.rewrite_query_params
     _DEFAULT_REWRITE_QUERY_MODEL = params.model
     # Only set override if user provided a custom prompt different from default
-    if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
+    if params.prompt != DEFAULT_QUERY_REWRITE_PROMPT:
         _REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
     else:
         _REWRITE_QUERY_PROMPT_OVERRIDE = None

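A small standalone sketch of the override rule applied above: a configured prompt becomes an override only when it differs from the built-in default. The helper name is hypothetical; the real code assigns module-level globals instead:

```python
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT


def resolve_prompt_override(configured_prompt: str) -> str | None:
    # Hypothetical helper mirroring set_default_rewrite_query_config: only a prompt that
    # differs from the built-in default is treated as a user override.
    return configured_prompt if configured_prompt != DEFAULT_QUERY_REWRITE_PROMPT else None


assert resolve_prompt_override(DEFAULT_QUERY_REWRITE_PROMPT) is None
assert resolve_prompt_override("Rewrite tersely: {query}") == "Rewrite tersely: {query}"
```
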
@@ -39,7 +39,7 @@ from llama_stack_api import (
 log = get_logger(name=__name__, category="providers::utils")
 
 from llama_stack.providers.utils.memory import rewrite_query_config
-from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
 
 
 class ChunkForDeletion(BaseModel):

@@ -312,7 +312,7 @@ class VectorStoreWithIndex:
             )
         else:
             # Use built-in default prompt and format with query
-            prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query=query)
+            prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query=query)
 
         request = OpenAIChatCompletionRequestWithExtraBody(
             model=model_id,

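A rough sketch of the fallback path touched above: when no prompt override is set, the built-in default is formatted with the user query and sent as a chat completion. The message shape and model id below are assumptions; only the prompt formatting and the model= parameter appear in the diff:

```python
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

query = "best open-weights embedding model"  # example user query
prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query=query)

# Assumed request payload; the real code builds an OpenAIChatCompletionRequestWithExtraBody.
request_body = {
    "model": "example/rewrite-model",  # hypothetical model id
    "messages": [{"role": "user", "content": prompt}],
}
```
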
@@ -1237,7 +1237,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
     from unittest.mock import MagicMock
 
     from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
-    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
     from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
     from llama_stack_api import QueryChunksResponse

@@ -1288,7 +1288,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
 
     # Verify default prompt is used (contains our built-in prompt text)
     prompt_text = chat_call_args.messages[0].content
-    expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
+    expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
     assert prompt_text == expected_prompt
 
     # Verify default inference parameters are used