Merge branch 'main' into fix/issue-2584-llama4-tool-calling

This commit is contained in:
Sumanth Kamenani 2025-07-15 09:12:15 -04:00 committed by GitHub
commit 5679d4dfd6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 669 additions and 507 deletions

View file

@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
return agent_config
@pytest.fixture(scope="session")
def agent_config_without_safety(text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
},
tools=[],
enable_session_persistence=False,
)
return agent_config
def test_agent_simple(llama_stack_client, agent_config):
agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
assert expected_kw in response.output_message.content.lower()
def test_rag_agent_with_attachments(llama_stack_client, agent_config):
def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
urls = ["llama3.rst", "lora_finetune.rst"]
documents = [
# passign as url
@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
metadata={},
),
]
rag_agent = Agent(llama_stack_client, **agent_config)
rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
]
user_prompts = [
(
"I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
assert "lora" in response.output_message.content.lower()
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_rag_and_code_agent(llama_stack_client, agent_config):
if "llama-4" in agent_config["model"].lower():
pytest.xfail("Not working for llama4")
documents = []
documents.append(
Document(
document_id="nba_wiki",
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
metadata={},
)
)
documents.append(
Document(
document_id="perplexity_wiki",
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
Srinivas, the CEO, worked at OpenAI as an AI researcher.
Konwinski was among the founding team at Databricks.
Yarats, the CTO, was an AI research scientist at Meta.
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
metadata={},
)
)
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=128,
)
agent_config = {
**agent_config,
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client, **agent_config)
user_prompts = [
(
"when was Perplexity the company founded?",
[],
"knowledge_search",
"2022",
),
(
"when was the nba created?",
[],
"knowledge_search",
"1949",
),
]
for prompt, docs, tool_name, expected_kw in user_prompts:
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
documents=docs,
stream=False,
)
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
if expected_kw:
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize(
"client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],

View file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)