mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? PR #639 introduced the notion of a Tools API and the ability to invoke tools through the API just like any other resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include: 1) Ability to specify tool groups with AgentConfig 2) The Agent gets the corresponding tool definitions for the specified tools and passes them along to the model 3) Attachments are now named Documents and their behavior is mostly unchanged from the user's perspective 4) You can specify args that can be injected into a tool call through the Agent config. This is especially useful in the case of the memory tool, where you want the tool to operate on a specific memory bank. 5) You can also register tool groups with args, which lets the agent inject these as well into the tool call. 6) All tests have been migrated to use the new tools API and fixtures, including client SDK tests 7) Telemetry just works with the tools API because of our trace protocol decorator ## Test Plan ``` pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py ``` run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994 Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
224 lines
7.2 KiB
Python
224 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
import pkg_resources
|
|
import yaml
|
|
from termcolor import colored
|
|
|
|
from llama_stack.apis.agents import Agents
|
|
from llama_stack.apis.batch_inference import BatchInference
|
|
from llama_stack.apis.datasetio import DatasetIO
|
|
from llama_stack.apis.datasets import Datasets
|
|
from llama_stack.apis.eval import Eval
|
|
from llama_stack.apis.eval_tasks import EvalTasks
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.inspect import Inspect
|
|
from llama_stack.apis.memory import Memory
|
|
from llama_stack.apis.memory_banks import MemoryBanks
|
|
from llama_stack.apis.models import Models
|
|
from llama_stack.apis.post_training import PostTraining
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.scoring import Scoring
|
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
|
from llama_stack.apis.shields import Shields
|
|
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
|
from llama_stack.apis.telemetry import Telemetry
|
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|
from llama_stack.distribution.datatypes import StackRunConfig
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
|
from llama_stack.distribution.store.registry import create_dist_registry
|
|
from llama_stack.providers.datatypes import Api
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Version label exposed by this module for the Llama Stack API.
LLAMA_STACK_API_VERSION = "alpha"
|
|
|
|
|
|
class LlamaStack(
    MemoryBanks,
    Inference,
    BatchInference,
    Agents,
    Safety,
    SyntheticDataGeneration,
    Datasets,
    Telemetry,
    PostTraining,
    Memory,
    Eval,
    EvalTasks,
    Scoring,
    ScoringFunctions,
    DatasetIO,
    Models,
    Shields,
    Inspect,
    ToolGroups,
    ToolRuntime,
):
    """Aggregate type that inherits from every API class listed above.

    The class adds no members of its own; it only combines its bases. The
    base order is significant for method resolution and must not be changed.
    """
|
|
|
|
|
|
# Table driving `register_resources`. Each entry is a 4-tuple:
#   (attribute name on the run config holding the objects to register,
#    API whose implementation handles them,
#    name of the registration method on that implementation,
#    name of the listing method on that implementation)
RESOURCES = [
    ("models", Api.models, "register_model", "list_models"),
    ("shields", Api.shields, "register_shield", "list_shields"),
    ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
    ("scoring_fns", Api.scoring_functions, "register_scoring_function", "list_scoring_functions"),
    ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
|
|
|
|
|
|
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
    """Register the resources declared in ``run_config`` and log what is served.

    For every entry in ``RESOURCES`` whose API has an implementation in
    ``impls``, each configured object is passed (as keyword arguments from
    ``model_dump()``) to the implementation's register method. The
    implementation's list method is then awaited and one log line is emitted
    per returned object, followed by an empty log line.

    Args:
        run_config: Run configuration; must expose one attribute per
            resource name listed in ``RESOURCES``.
        impls: Mapping from API to its implementation object.
    """
    for rsrc, api, register_method, list_method in RESOURCES:
        # Read the config attribute first so a missing attribute fails even
        # when the API has no implementation.
        configured = getattr(run_config, rsrc)
        if api not in impls:
            continue

        register = getattr(impls[api], register_method)
        for item in configured:
            await register(**item.model_dump())

        list_all = getattr(impls[api], list_method)
        for served in await list_all():
            identifier = colored(served.identifier, "white", attrs=["bold"])
            provider = colored(served.provider_id, "white", attrs=["bold"])
            log.info(f"{rsrc.capitalize()}: {identifier} served by {provider}")

        # Blank line separating one resource type's listing from the next.
        log.info("")
|
|
|
|
|
|
class EnvVarError(Exception):
    """Error for an environment variable that is not set or is empty.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, var_name: str, path: str = ""):
        """Build the error message from the variable name and optional path.

        Args:
            var_name: Name of the environment variable.
            path: Location of the reference (e.g. a config path); omitted
                from the message when empty.
        """
        location = f" at {path}" if path else ""
        super().__init__(f"Environment variable '{var_name}' not set or empty{location}")
        # Kept as attributes so handlers can inspect them without parsing the message.
        self.var_name = var_name
        self.path = path
|
|
|
|
|
|
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a redacted copy of ``data`` that is safe to print.

    A key is considered sensitive when its lower-cased name contains one of
    ``api_key``, ``api_token``, ``password`` or ``secret``.

    Rules, applied recursively:
      * A sensitive key whose value is not a dict (a scalar *or* a list) has
        its value replaced by ``"********"``. Lists are masked wholesale so
        that e.g. ``api_keys: [k1, k2]`` does not leak its elements.
      * Dict values are always recursed into, so only their sensitive
        entries are masked.
      * Lists under non-sensitive keys are walked at any nesting depth and
        every dict found inside them is redacted.
      * Everything else is copied through unchanged.

    Args:
        data: Configuration mapping to redact. It is not modified.

    Returns:
        A new dict with the same keys and the sensitive values masked.
    """
    sensitive_patterns = ("api_key", "api_token", "password", "secret")
    mask = "********"

    def _is_sensitive(key: Any) -> bool:
        # str() keeps non-string keys (YAML allows ints, bools, ...) from
        # raising on .lower().
        lowered = str(key).lower()
        return any(pattern in lowered for pattern in sensitive_patterns)

    def _redact_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _redact_dict(value)
        if isinstance(value, list):
            # Recurse so dicts nested inside lists-of-lists are also covered.
            return [_redact_value(item) for item in value]
        return value

    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
        result = {}
        for k, v in d.items():
            if _is_sensitive(k) and not isinstance(v, dict):
                result[k] = mask
            else:
                result[k] = _redact_value(v)
        return result

    return _redact_dict(data)
|
|
|
|
|
|
def replace_env_vars(config: Any, path: str = "") -> Any:
    """Recursively substitute ``${env.NAME}`` references in a config value.

    Dicts and lists are rebuilt with every element processed; strings have
    each ``${env.NAME}`` or ``${env.NAME:default}`` occurrence replaced; any
    other value is returned unchanged. ``NAME`` may contain only upper-case
    letters, digits and underscores.

    For each reference, the environment value is used when it is set and
    non-empty; otherwise the default (which may be empty) is used. The chosen
    value is passed through ``os.path.expanduser`` so a leading ``~`` expands.

    Args:
        config: Value to process (typically the parsed YAML document).
        path: Dotted/indexed location of ``config`` within the whole
            document, used only for error reporting.

    Returns:
        A new structure with substitutions applied. Substituted values are
        always strings.

    Raises:
        EnvVarError: A referenced variable is unset or empty and the
            reference has no default. The error carries the variable name
            and the path of the offending string.
    """
    # The error raised in the string branch already carries the variable name
    # and full path, so it propagates unchanged through the recursion. No
    # per-level catch/re-raise is needed, and the traceback stays intact.
    if isinstance(config, dict):
        return {
            k: replace_env_vars(v, f"{path}.{k}" if path else k)
            for k, v in config.items()
        }

    if isinstance(config, list):
        return [replace_env_vars(v, f"{path}[{i}]") for i, v in enumerate(config)]

    if isinstance(config, str):
        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"

        def get_env_var(match):
            env_var = match.group(1)
            # None when the reference has no ":default" part at all.
            default_val = match.group(2)

            value = os.environ.get(env_var)
            if not value:
                # Unset and empty are treated alike.
                if default_val is None:
                    raise EnvVarError(env_var, path)
                value = default_val

            # expand "~" from the values
            return os.path.expanduser(value)

        return re.sub(pattern, get_env_var, config)

    return config
|
|
|
|
|
|
def validate_env_pair(env_pair: str) -> tuple[str, str]:
    """Validate and split an environment variable key-value pair.

    The pair is split on the first ``=``. The key is stripped of surrounding
    whitespace and must be non-empty and consist only of alphanumeric
    characters and underscores. The value is returned exactly as given and
    may be empty or contain further ``=`` characters.

    Args:
        env_pair: Text of the form ``KEY=value``.

    Returns:
        The ``(key, value)`` tuple.

    Raises:
        ValueError: The ``=`` separator is missing, the key is empty, or the
            key contains a disallowed character. The message always names
            the offending input and the expected format.
    """
    # partition() never raises, so a missing "=" is reported explicitly
    # rather than leaking Python's "not enough values to unpack" text.
    key, separator, value = env_pair.partition("=")
    key = key.strip()

    if not separator:
        problem = "missing '=' separator"
    elif not key:
        problem = f"Empty key in environment variable pair: {env_pair}"
    elif not all(c.isalnum() or c == "_" for c in key):
        problem = f"Key must contain only alphanumeric characters and underscores: {key}"
    else:
        return key, value

    raise ValueError(
        f"Invalid environment variable format '{env_pair}': {problem}. Expected format: KEY=value"
    )
|
|
|
|
|
|
# Builds the provider implementations for the given run config. The run
# config need not ask for every API.
async def construct_stack(
    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
    """Resolve implementations for ``run_config`` and register its resources.

    Args:
        run_config: The stack run configuration.
        provider_registry: Registry to resolve providers from. When omitted
            (or otherwise falsy) the result of ``get_provider_registry()``
            is used.

    Returns:
        The mapping from API to implementation produced by ``resolve_impls``,
        after ``register_resources`` has run against it.
    """
    # create_dist_registry returns a pair; only the first element is used here.
    dist_registry, _unused = await create_dist_registry(
        run_config.metadata_store,
        run_config.image_name,
    )

    providers = provider_registry or get_provider_registry()
    impls = await resolve_impls(run_config, providers, dist_registry)

    await register_resources(run_config, impls)
    return impls
|
|
|
|
|
|
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
    """Load the ``run.yaml`` of a packaged template as a ``StackRunConfig``.

    The file is looked up at ``templates/<template>/run.yaml`` inside the
    ``llama_stack`` package, parsed with ``yaml.safe_load``, and passed
    through ``replace_env_vars`` before the config object is constructed.

    Args:
        template: Name of the template directory.

    Returns:
        The run configuration built from the template.

    Raises:
        ValueError: No ``run.yaml`` exists for ``template``.
        EnvVarError: Propagated from ``replace_env_vars`` when the template
            references an unset variable without a default.
    """
    template_path = pkg_resources.resource_filename(
        "llama_stack", f"templates/{template}/run.yaml"
    )

    if not Path(template_path).exists():
        raise ValueError(f"Template '{template}' not found at {template_path}")

    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(template_path, encoding="utf-8") as f:
        run_config = yaml.safe_load(f)

    return StackRunConfig(**replace_env_vars(run_config))
|