mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? PR #639 introduced the notion of Tools API and ability to invoke tools through API just as any resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) Agent gets the corresponding tool definitions for the specified tools and pass along to the model 3) Attachements are now named as Documents and their behavior is mostly unchanged from user perspective 4) You can specify args that can be injected to a tool call through Agent config. This is especially useful in case of memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use new tools API and fixtures including client SDK tests 7) Telemetry just works with tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
161 lines
4.9 KiB
Python
161 lines
4.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
from pydantic import BaseModel
|
|
from termcolor import colored
|
|
|
|
from llama_stack.distribution.datatypes import Provider
|
|
from llama_stack.providers.datatypes import RemoteProviderConfig
|
|
|
|
from .env import get_env_or_fail
|
|
|
|
|
|
class ProviderFixture(BaseModel):
|
|
providers: List[Provider]
|
|
provider_data: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
def remote_stack_fixture() -> ProviderFixture:
|
|
if url := os.getenv("REMOTE_STACK_URL", None):
|
|
config = RemoteProviderConfig.from_url(url)
|
|
else:
|
|
config = RemoteProviderConfig(
|
|
host=get_env_or_fail("REMOTE_STACK_HOST"),
|
|
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
|
|
)
|
|
return ProviderFixture(
|
|
providers=[
|
|
Provider(
|
|
provider_id="test::remote",
|
|
provider_type="test::remote",
|
|
config=config.model_dump(),
|
|
)
|
|
],
|
|
)
|
|
|
|
|
|
def pytest_configure(config):
|
|
config.option.tbstyle = "short"
|
|
config.option.disable_warnings = True
|
|
|
|
"""Load environment variables at start of test run"""
|
|
# Load from .env file if it exists
|
|
env_file = Path(__file__).parent / ".env"
|
|
if env_file.exists():
|
|
load_dotenv(env_file)
|
|
|
|
# Load any environment variables passed via --env
|
|
env_vars = config.getoption("--env") or []
|
|
for env_var in env_vars:
|
|
key, value = env_var.split("=", 1)
|
|
os.environ[key] = value
|
|
|
|
|
|
def pytest_addoption(parser):
|
|
parser.addoption(
|
|
"--providers",
|
|
default="",
|
|
help=(
|
|
"Provider configuration in format: api1=provider1,api2=provider2. "
|
|
"Example: --providers inference=ollama,safety=meta-reference"
|
|
),
|
|
)
|
|
"""Add custom command line options"""
|
|
parser.addoption(
|
|
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
|
|
)
|
|
|
|
|
|
def make_provider_id(providers: Dict[str, str]) -> str:
|
|
return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items()))
|
|
|
|
|
|
def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
|
|
marks = []
|
|
for provider in providers.values():
|
|
marks.append(getattr(pytest.mark, provider))
|
|
return marks
|
|
|
|
|
|
def get_provider_fixture_overrides(
|
|
config, available_fixtures: Dict[str, List[str]]
|
|
) -> Optional[List[pytest.param]]:
|
|
provider_str = config.getoption("--providers")
|
|
if not provider_str:
|
|
return None
|
|
|
|
fixture_dict = parse_fixture_string(provider_str, available_fixtures)
|
|
return [
|
|
pytest.param(
|
|
fixture_dict,
|
|
id=make_provider_id(fixture_dict),
|
|
marks=get_provider_marks(fixture_dict),
|
|
)
|
|
]
|
|
|
|
|
|
def parse_fixture_string(
|
|
provider_str: str, available_fixtures: Dict[str, List[str]]
|
|
) -> Dict[str, str]:
|
|
"""Parse provider string of format 'api1=provider1,api2=provider2'"""
|
|
if not provider_str:
|
|
return {}
|
|
|
|
fixtures = {}
|
|
pairs = provider_str.split(",")
|
|
for pair in pairs:
|
|
if "=" not in pair:
|
|
raise ValueError(
|
|
f"Invalid provider specification: {pair}. Expected format: api=provider"
|
|
)
|
|
api, fixture = pair.split("=")
|
|
if api not in available_fixtures:
|
|
raise ValueError(
|
|
f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
|
|
)
|
|
if fixture not in available_fixtures[api]:
|
|
raise ValueError(
|
|
f"Unknown provider '{fixture}' for API '{api}'. "
|
|
f"Available providers: {list(available_fixtures[api])}"
|
|
)
|
|
fixtures[api] = fixture
|
|
|
|
# Check that all provided APIs are supported
|
|
for api in available_fixtures.keys():
|
|
if api not in fixtures:
|
|
raise ValueError(
|
|
f"Missing provider fixture for API '{api}'. Available providers: "
|
|
f"{list(available_fixtures[api])}"
|
|
)
|
|
return fixtures
|
|
|
|
|
|
def pytest_itemcollected(item):
|
|
# Get all markers as a list
|
|
filtered = ("asyncio", "parametrize")
|
|
marks = [mark.name for mark in item.iter_markers() if mark.name not in filtered]
|
|
if marks:
|
|
marks = colored(",".join(marks), "yellow")
|
|
item.name = f"{item.name}[{marks}]"
|
|
|
|
|
|
pytest_plugins = [
|
|
"llama_stack.providers.tests.inference.fixtures",
|
|
"llama_stack.providers.tests.safety.fixtures",
|
|
"llama_stack.providers.tests.memory.fixtures",
|
|
"llama_stack.providers.tests.agents.fixtures",
|
|
"llama_stack.providers.tests.datasetio.fixtures",
|
|
"llama_stack.providers.tests.scoring.fixtures",
|
|
"llama_stack.providers.tests.eval.fixtures",
|
|
"llama_stack.providers.tests.post_training.fixtures",
|
|
"llama_stack.providers.tests.tools.fixtures",
|
|
]
|