llama-stack/llama_stack/providers/tests/resolver.py
Dinesh Yeduguru a5c57cd381
agents to use tools api (#673)
# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools
through the API just like any other resource. This PR changes the Agents to start
using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig
2) The agent fetches the corresponding tool definitions for the specified tools
and passes them along to the model
3) Attachments are now called Documents, and their behavior is mostly
unchanged from the user's perspective
4) You can specify args to be injected into a tool call through the
agent config. This is especially useful for the memory tool, where
you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent
inject these as well into the tool call.
6) All tests, including the client SDK tests, have been migrated to use the
new Tools API and fixtures
7) Telemetry works out of the box with the Tools API thanks to our trace
protocol decorator


## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
2025-01-08 19:01:00 -08:00

103 lines
3.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import tempfile
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.memory_banks import MemoryBankInput
from llama_stack.apis.models import ModelInput
from llama_stack.apis.scoring_functions import ScoringFnInput
from llama_stack.apis.shields import ShieldInput
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_remote_stack_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestStack(BaseModel):
    """Result of ``construct_stack_for_test``.

    Bundles the API implementations resolved for a test stack together with
    the parsed run configuration they were built from.
    """

    # Mapping from each API to the implementation object serving it, as
    # returned by ``construct_stack`` (local) or ``resolve_remote_stack_impls``
    # (remote).
    impls: Dict[Api, Any]
    # The parsed/upgraded run configuration the stack was constructed from.
    run_config: StackRunConfig
async def construct_stack_for_test(
    apis: List[Api],
    providers: Dict[str, List[Provider]],
    provider_data: Optional[Dict[str, Any]] = None,
    models: Optional[List[ModelInput]] = None,
    shields: Optional[List[ShieldInput]] = None,
    memory_banks: Optional[List[MemoryBankInput]] = None,
    datasets: Optional[List[DatasetInput]] = None,
    scoring_fns: Optional[List[ScoringFnInput]] = None,
    eval_tasks: Optional[List[EvalTaskInput]] = None,
    tool_groups: Optional[List[ToolGroupInput]] = None,
) -> TestStack:
    """Build a ``TestStack`` for the given APIs, providers and resources.

    A fresh SQLite file is created to back the stack's metadata store. If
    ``remote_provider_config`` finds a ``test::remote`` provider, the APIs are
    resolved against that already-running remote stack, and the resources
    passed here are not registered on it. Otherwise a local stack is built
    with ``construct_stack`` from the provider registry.

    Args:
        apis: APIs the stack should serve.
        providers: Providers per key (presumably the API name -- confirm
            against ``StackRunConfig``).
        provider_data: Optional provider data. When given, it is JSON-encoded
            and installed as the ``X-LlamaStack-ProviderData`` request header.
        models, shields, memory_banks, datasets, scoring_fns, eval_tasks,
        tool_groups: Resources to put in the run config; ``None`` means an
            empty list.

    Returns:
        A ``TestStack`` holding the resolved implementations and the parsed
        run config.

    Raises:
        ModuleNotFoundError: A provider dependency is missing. Pip install
            hints are printed before the error is re-raised.
    """
    # Only the path is handed to SqliteKVStoreConfig. Keep the file on disk
    # (delete=False) but close our handle immediately so the descriptor is
    # not leaked for the rest of the test session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as sqlite_file:
        db_path = sqlite_file.name

    run_config = dict(
        image_name="test-fixture",
        apis=apis,
        providers=providers,
        metadata_store=SqliteKVStoreConfig(db_path=db_path),
        models=models or [],
        shields=shields or [],
        memory_banks=memory_banks or [],
        datasets=datasets or [],
        scoring_fns=scoring_fns or [],
        eval_tasks=eval_tasks or [],
        tool_groups=tool_groups or [],
    )
    run_config = parse_and_maybe_upgrade_config(run_config)
    try:
        remote_config = remote_provider_config(run_config)
        if not remote_config:
            # TODO: add to provider registry by creating interesting mocks or fakes
            impls = await construct_stack(run_config, get_provider_registry())
        else:
            # we don't register resources for a remote stack as part of the fixture setup
            # because the stack is already "up". if a test needs to register resources, it
            # can do so manually always.
            impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
        test_stack = TestStack(impls=impls, run_config=run_config)
    except ModuleNotFoundError:
        print_pip_install_help(providers)
        # Bare raise re-raises the active exception with its original traceback.
        raise

    if provider_data:
        set_request_provider_data(
            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
        )
    return test_stack
def remote_provider_config(
    run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
    """Return the remote-stack config if ``run_config`` targets a remote stack.

    Every provider of every API is inspected. Each provider whose type is
    ``test::remote`` has its ``config`` converted into a
    ``RemoteProviderConfig``. When there are several, the last one seen wins.

    Returns:
        The ``RemoteProviderConfig`` from the last ``test::remote`` provider,
        or ``None`` if there is no such provider.

    Raises:
        AssertionError: A ``test::remote`` provider is mixed with any other
            provider type.
    """
    found = None
    saw_local_provider = False

    every_provider = (
        prov for per_api in run_config.providers.values() for prov in per_api
    )
    for prov in every_provider:
        if prov.provider_type != "test::remote":
            saw_local_provider = True
            continue
        found = RemoteProviderConfig(**prov.config)

    # A remote stack must be used on its own; mixing it with local providers
    # is a configuration error.
    if found:
        assert not saw_local_provider, "Remote stack cannot have non-remote providers"
    return found