Merge branch 'meta-llama:main' into main

This commit is contained in:
Chacksu 2025-01-09 12:34:38 -05:00 committed by GitHub
commit 0ac4d2fced
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
119 changed files with 5069 additions and 2779 deletions

View file

@ -7,13 +7,12 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
@ -21,6 +20,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@ -31,6 +31,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="ollama",
marks=pytest.mark.ollama,
@ -42,6 +43,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="together",
marks=pytest.mark.together,
@ -52,6 +54,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="fireworks",
marks=pytest.mark.fireworks,
@ -62,6 +65,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "remote",
"memory": "remote",
"agents": "remote",
"tool_runtime": "memory_and_search",
},
id="remote",
marks=pytest.mark.remote,
@ -117,6 +121,7 @@ def pytest_generate_tests(metafunc):
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)

View file

@ -11,13 +11,12 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
@ -59,12 +58,18 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_shield):
async def agents_stack(
request,
inference_model,
safety_shield,
tool_group_input_memory,
tool_group_input_tavily_search,
):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "safety", "memory", "agents"]:
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -113,10 +118,11 @@ async def agents_stack(request, inference_model, safety_shield):
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,
shields=[safety_shield] if safety_shield else [],
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
)
return test_stack

View file

@ -5,22 +5,17 @@
# the root directory of this source tree.
import os
from typing import Dict, List
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
Attachment,
MemoryToolDefinition,
SearchEngineType,
SearchToolDefinition,
Document,
ShieldCallStep,
StepType,
ToolChoice,
@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
#
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
# -m "meta_reference"
from .fixtures import pick_inference_model
from .utils import create_agent_session
@ -51,7 +45,7 @@ def common_params(inference_model):
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
toolgroups=[],
max_infer_iters=5,
)
@ -88,73 +82,6 @@ def query_attachment_messages():
]
async def create_agent_turn_with_search_tool(
agents_stack: Dict[str, object],
search_query_messages: List[object],
common_params: Dict[str, str],
search_tool_definition: SearchToolDefinition,
) -> None:
"""
Create an agent turn with a search tool.
Args:
agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition.
"""
# Create an agent with the search tool
agent_config = AgentConfig(
**{
**common_params,
"tools": [search_tool_definition],
}
)
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(
@ -227,7 +154,7 @@ class TestAgents:
check_turn_complete_event(turn_response, session_id, sample_messages)
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
async def test_rag_agent(
self,
agents_stack,
attachment_message,
@ -243,29 +170,17 @@ class TestAgents:
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
documents = [
Document(
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
agent_config = AgentConfig(
**{
**common_params,
"tools": [
MemoryToolDefinition(
memory_bank_configs=[],
query_generator_config={
"type": "default",
"sep": " ",
},
max_tokens_in_context=4096,
max_chunks=10,
),
],
"toolgroups": ["builtin::memory"],
"tool_choice": ToolChoice.auto,
}
)
@ -275,7 +190,7 @@ class TestAgents:
agent_id=agent_id,
session_id=session_id,
messages=attachment_message,
attachments=attachments,
documents=documents,
stream=True,
)
turn_response = [
@ -298,22 +213,6 @@ class TestAgents:
assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
self, agents_stack, search_query_messages, common_params
):
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value,
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
engine=SearchEngineType.brave,
)
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
)
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
@ -321,14 +220,57 @@ class TestAgents:
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value, # place holder only
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
engine=SearchEngineType.tavily,
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"toolgroups": ["builtin::web_search"],
}
)
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response):