more test fixes

This commit is contained in:
Dinesh Yeduguru 2024-12-30 13:27:43 -08:00
parent 40439509ca
commit c2dd0cdc78
7 changed files with 63 additions and 120 deletions

1
.gitignore vendored
View file

@ -19,4 +19,3 @@ Package.resolved
_build
docs/src
pyrightconfig.json
.aider*

View file

@ -10,7 +10,8 @@ from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(

View file

@ -4,21 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import (
BuiltInToolDef,
CustomToolDef,
ToolGroupInput,
ToolParameter,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
@ -63,32 +54,17 @@ def agents_meta_reference() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="memory-runtime",
provider_type="inline::memory-runtime",
config={},
),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
config={
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
},
),
],
)
AGENTS_FIXTURES = ["meta_reference", "remote"]
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_shield):
async def agents_stack(
request,
inference_model,
safety_shield,
tool_group_input_memory,
tool_group_input_tavily_search,
):
fixture_dict = request.param
providers = {}
@ -140,47 +116,6 @@ async def agents_stack(request, inference_model, safety_shield):
metadata={"embedding_dimension": 384},
)
)
tool_groups = [
ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
BuiltInToolDef(
built_in_type=BuiltinTool.brave_search,
metadata={},
),
],
),
provider_id="tavily-search",
),
ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
CustomToolDef(
name="memory",
description="memory",
parameters=[
ToolParameter(
name="input_messages",
description="messages",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
),
]
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
@ -188,6 +123,6 @@ async def agents_stack(request, inference_model, safety_shield):
provider_data,
models=models,
shields=[safety_shield] if safety_shield else [],
tool_groups=tool_groups,
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
)
return test_stack

View file

@ -22,6 +22,8 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api

View file

@ -157,4 +157,5 @@ pytest_plugins = [
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
"llama_stack.providers.tests.tools.fixtures",
]

View file

@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES, tools_stack # noqa: F401
from .fixtures import TOOL_RUNTIME_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(

View file

@ -44,11 +44,55 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
CustomToolDef(
name="memory",
description="Query the memory bank",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
)
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
),
provider_id="tavily-search",
)
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(request, inference_model):
async def tools_stack(
request, inference_model, tool_group_input_memory, tool_group_input_tavily_search
):
fixture_dict = request.param
providers = {}
@ -86,53 +130,14 @@ async def tools_stack(request, inference_model):
)
)
tool_groups = [
ToolGroupInput(
tool_group_id="tavily_search_group",
tool_group=UserDefinedToolGroupDef(
tools=[
BuiltInToolDef(
built_in_type=BuiltinTool.brave_search,
metadata={},
),
],
),
provider_id="tavily-search",
),
ToolGroupInput(
tool_group_id="memory_group",
tool_group=UserDefinedToolGroupDef(
tools=[
CustomToolDef(
name="memory",
description="Query the memory bank",
parameters=[
ToolParameter(
name="input_messages",
description="The input messages to search for in memory",
parameter_type="list",
required=True,
),
],
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
),
]
test_stack = await construct_stack_for_test(
[Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,
tool_groups=tool_groups,
tool_groups=[
tool_group_input_tavily_search,
tool_group_input_memory,
],
)
return test_stack