Merge branch 'main' into update-completions-skipping-for-groq

This commit is contained in:
raghotham 2025-09-06 12:36:17 -07:00 committed by GitHub
commit 3ebf5bd407
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 134 additions and 43 deletions

View file

@ -40,6 +40,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
# {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
# or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
"remote::groq",
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@ -66,6 +67,12 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
if provider.provider_type in (
"remote::sambanova",
"remote::ollama",
# https://console.groq.com/docs/openai#currently-unsupported-openai-features
# -> Error code: 400 - {'error': {'message': "'n' : number must be at most 1", 'type': 'invalid_request_error'}}
"remote::groq",
# Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the
# current model', 'status': 'INVALID_ARGUMENT'}}]
"remote::gemini",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

View file

@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models):
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
clear_registry()
try:
client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
except Exception:
pass
yield client_with_models
# you must clean after the last test if you were running tests against
# a stateful server instance
clear_registry()
@ -66,12 +70,13 @@ def assert_valid_text_response(response):
def test_vector_db_insert_inline_and_query(
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_vector_db"
vector_db = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
vector_db_id = vector_db.identifier
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=sample_documents,
@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query(
# list to check memory bank is successfully registered
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs
# VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [
"memory_optimizations.rst",
@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query(
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512,
)
# Query for the name of method
response1 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query="What's the name of the fine-tunning method used?",
)
assert_valid_chunk_response(response1)
@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query(
# Query for the name of model
response2 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query="Which Llama model is mentioned?",
)
assert_valid_chunk_response(response2)
@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs
# VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [
"memory_optimizations.rst",
@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512,
)
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
)
assert_valid_text_response(response_with_metadata)
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"include_metadata_in_content": True,
@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
with pytest.raises((ValueError, BadRequestError)):
client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",

View file

@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]