refactor to only configure the model at build time

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

parent 2cc7943fd6
commit d887f1f8bb

31 changed files with 280 additions and 315 deletions
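What the refactor means in practice, pieced together from the added test lines below: query expansion is now driven by a single global configuration registered at build time via set_default_query_expansion_config, replacing the per-store query_expansion_model override. A minimal sketch, assuming the import paths shown in the diff; QueryExpansionSettings is a hypothetical stand-in for whatever config object the stack actually passes (the test itself uses a MagicMock exposing these attributes):

    from dataclasses import dataclass

    from llama_stack.core.datatypes import QualifiedModel
    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config


    @dataclass
    class QueryExpansionSettings:
        # Hypothetical stand-in: only the attributes the test exercises.
        default_query_expansion_model: QualifiedModel | None = None
        query_expansion_prompt: str | None = None  # None -> built-in DEFAULT_QUERY_EXPANSION_PROMPT
        query_expansion_max_tokens: int = 100  # defaults asserted in the test
        query_expansion_temperature: float = 0.3


    # Register the global default once, at build time; per-store overrides are gone.
    set_default_query_expansion_config(
        QueryExpansionSettings(
            default_query_expansion_model=QualifiedModel(provider_id="test", model_id="llama"),
        )
    )

    # Passing None clears the config, which disables query expansion entirely.
    set_default_query_expansion_config(None)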
@@ -1233,94 +1233,122 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):

 async def test_query_expansion_functionality(vector_io_adapter):
-    """Test query expansion with per-store models, global defaults, and error validation."""
+    """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock

-    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
+    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
-    from llama_stack_api.models import Model, ModelType
+    from llama_stack_api import QueryChunksResponse

-    vector_io_adapter.register_vector_store = AsyncMock()
-    vector_io_adapter.__provider_id__ = "test_provider"
-
-    # Test 1: Per-store model usage
-    params = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store",
-        metadata={},
-        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
-    )
-    await vector_io_adapter.openai_create_vector_store(params)
-    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args.query_expansion_model == "test/llama-model"
-
-    # Test 2: Global default fallback
-    vector_io_adapter.register_vector_store.reset_mock()
-    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
-        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
-    )
-    await vector_io_adapter.openai_create_vector_store(params_no_model)
-    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
-    assert call_args2.query_expansion_model is None
-
-    # Test query rewriting scenarios
-    mock_inference_api = MagicMock()
-
-    # Per-store model scenario
+    # Mock a simple vector store and index
     mock_vector_store = MagicMock()
-    mock_vector_store.query_expansion_model = "test/llama-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
-        )
-    )
-    mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
-    )
+    mock_vector_store.embedding_model = "test/embedding"
+    mock_inference_api = MagicMock()
+    mock_index = MagicMock()

+    # Create VectorStoreWithIndex with simplified constructor
     vector_store_with_index = VectorStoreWithIndex(
         vector_store=mock_vector_store,
-        index=MagicMock(),
+        index=mock_index,
         inference_api=mock_inference_api,
-        vector_stores_config=VectorStoresConfig(
-            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
-        ),
     )

-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
-    assert result == "per-store expanded"
+    # Mock the query_vector method to return a simple response
+    mock_response = QueryChunksResponse(chunks=[], scores=[])
+    mock_index.query_vector = AsyncMock(return_value=mock_response)

-    # Global default fallback scenario
-    mock_inference_api.reset_mock()
-    mock_vector_store.query_expansion_model = None
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
-        )
-    )
+    # Mock embeddings generation
+    mock_inference_api.openai_embeddings = AsyncMock(
+        return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
+    )
+
+    # Test 1: Query expansion with default prompt (no custom prompt configured)
+    mock_vector_stores_config = MagicMock()
+    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
+    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
+    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
+    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+
+    # Set global config
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(
-        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))])
     )

-    result = await vector_store_with_index._rewrite_query_for_search("test")
-    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
-    assert result == "global expanded"
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result = await vector_store_with_index.query_chunks("test query", params)

-    # Test 3: Error cases
-    # Model not found
-    mock_vector_store.query_expansion_model = "missing/model"
-    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
+    # Verify chat completion was called for query rewriting
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    assert chat_call_args.model == "test/llama"

-    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify default prompt is used (contains our built-in prompt text)
+    prompt_text = chat_call_args.messages[0].content
+    expected_prompt = DEFAULT_QUERY_EXPANSION_PROMPT.format(query="test query")
+    assert prompt_text == expected_prompt

-    # Non-LLM model
-    mock_vector_store.query_expansion_model = "test/embedding-model"
-    mock_inference_api.routing_table.list_models = AsyncMock(
-        return_value=MagicMock(
-            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
-        )
-    )
+    # Verify default inference parameters are used
+    assert chat_call_args.max_tokens == 100  # Default value
+    assert chat_call_args.temperature == 0.3  # Default value

-    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
-        await vector_store_with_index._rewrite_query_for_search("test")
+    # Verify the rest of the flow proceeded normally
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result == mock_response
+
+    # Test 1b: Query expansion with custom prompt override and inference parameters
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
+    mock_vector_stores_config.query_expansion_max_tokens = 150
+    mock_vector_stores_config.query_expansion_temperature = 0.7
+    set_default_query_expansion_config(mock_vector_stores_config)
+
+    result = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify custom prompt and parameters are used
+    mock_inference_api.openai_chat_completion.assert_called_once()
+    chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
+    prompt_text = chat_call_args.messages[0].content
+    assert prompt_text == "Custom prompt for rewriting: test query"
+    assert "Expand this query with relevant synonyms" not in prompt_text  # Default not used
+
+    # Verify custom inference parameters
+    assert chat_call_args.max_tokens == 150
+    assert chat_call_args.temperature == 0.7
+
+    # Test 2: No query expansion when no global model is configured
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    # Clear global config
+    set_default_query_expansion_config(None)
+
+    params = {"rewrite_query": True, "max_chunks": 5}
+    result2 = await vector_store_with_index.query_chunks("test query", params)
+
+    # Verify chat completion was NOT called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+    # But normal flow should still work
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result2 == mock_response
+
+    # Test 3: Normal behavior without rewrite_query parameter
+    mock_inference_api.reset_mock()
+    mock_index.reset_mock()
+
+    params_no_rewrite = {"max_chunks": 5}
+    result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite)
+
+    # Neither chat completion nor query rewriting should be called
+    mock_inference_api.openai_chat_completion.assert_not_called()
+    mock_inference_api.openai_embeddings.assert_called_once()
+    mock_index.query_vector.assert_called_once()
+    assert result3 == mock_response
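For readers skimming the assertions, the gating behavior the rewritten test pins down for query_chunks can be restated in a few lines. The helper below is purely illustrative and does not exist in llama-stack; it only summarizes what Tests 1 through 3 assert:

    # Illustrative only: a restatement of the gating the test asserts, not real
    # llama-stack code. Rewriting happens only when the caller opts in AND a
    # global query-expansion config has been registered.
    def should_rewrite_query(params: dict, global_config) -> bool:
        if not params.get("rewrite_query", False):
            return False  # Test 3: no rewrite_query -> chat completion never called
        if global_config is None:
            return False  # Test 2: config cleared -> expansion skipped, search still works
        return True  # Tests 1/1b: configured global model rewrites the query first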