forked from phoenix-oss/llama-stack-mirror
# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools through the API just like any other resource. This PR changes the Agents to start using the Tools API to invoke tools. Major changes include:

1) Ability to specify tool groups with AgentConfig.
2) The agent gets the corresponding tool definitions for the specified tools and passes them along to the model.
3) Attachments are now named Documents; their behavior is mostly unchanged from the user's perspective.
4) You can specify args that can be injected into a tool call through the agent config. This is especially useful for the memory tool, where you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which lets the agent inject these into the tool call as well.
6) All tests have been migrated to use the new tools API and fixtures, including the client SDK tests.
7) Telemetry just works with the tools API because of our trace protocol decorator.

## Test Plan

```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```

run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994
Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
185 lines
6.5 KiB
Python
185 lines
6.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from pathlib import Path
|
|
from typing import Dict, List, Literal, Optional, Tuple
|
|
|
|
import jinja2
|
|
import yaml
|
|
from pydantic import BaseModel, Field
|
|
|
|
from llama_stack.apis.models.models import ModelType
|
|
from llama_stack.distribution.datatypes import (
|
|
Api,
|
|
BuildConfig,
|
|
DistributionSpec,
|
|
ModelInput,
|
|
Provider,
|
|
ShieldInput,
|
|
StackRunConfig,
|
|
ToolGroupInput,
|
|
)
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
|
|
class RunConfigSettings(BaseModel):
    """Settings used to render one ``StackRunConfig`` for a distribution.

    Holds optional per-API provider overrides plus the default models,
    shields, and tool groups that are copied into the generated run config.
    """

    # API name -> providers to use verbatim instead of the generated defaults.
    # An empty list counts as "no override" (see the truthiness check below).
    provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
    default_models: Optional[List[ModelInput]] = None
    default_shields: Optional[List[ShieldInput]] = None
    default_tool_groups: Optional[List[ToolGroupInput]] = None

    def run_config(
        self,
        name: str,
        providers: Dict[str, List[str]],
        docker_image: Optional[str] = None,
    ) -> StackRunConfig:
        """Build the ``StackRunConfig`` for the distribution ``name``.

        For every API in ``providers``, either the override from
        ``provider_overrides`` is used as-is, or one ``Provider`` entry is
        generated per provider type. The entry's config comes from the
        provider's ``sample_run_config`` when its config class defines one,
        otherwise an empty dict.

        Args:
            name: Distribution name. Used as the image name and conda env
                name, and to derive the ``distributions/{name}`` directory
                passed to the sample-config generators.
            providers: Mapping of API name to the provider types to configure
                for that API.
            docker_image: Optional docker image recorded in the run config.

        Returns:
            The assembled ``StackRunConfig``.

        Raises:
            ValueError: If a provider type is not registered for its API, or
                the registered provider has no config class.
        """
        provider_registry = get_provider_registry()
        distro_dir = f"distributions/{name}"

        provider_configs: Dict[str, List[Provider]] = {}
        for api_str, provider_types in providers.items():
            # A non-empty override replaces the generated providers entirely.
            if api_providers := self.provider_overrides.get(api_str):
                provider_configs[api_str] = api_providers
                continue

            provider_configs[api_str] = []
            for provider_type in provider_types:
                # The provider id is the last "::"-separated segment of the
                # type, e.g. "remote::foo" -> "foo".
                provider_id = provider_type.split("::")[-1]

                api = Api(api_str)
                if provider_type not in provider_registry[api]:
                    raise ValueError(
                        f"Unknown provider type: {provider_type} for API: {api_str}"
                    )

                config_class = provider_registry[api][provider_type].config_class
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O`, which would defer the failure to
                # instantiate_class_type(None).
                if config_class is None:
                    raise ValueError(
                        f"No config class for provider type: {provider_type} for API: {api_str}"
                    )

                config_class = instantiate_class_type(config_class)
                if hasattr(config_class, "sample_run_config"):
                    config = config_class.sample_run_config(__distro_dir__=distro_dir)
                else:
                    config = {}

                provider_configs[api_str].append(
                    Provider(
                        provider_id=provider_id,
                        provider_type=provider_type,
                        config=config,
                    )
                )

        # Dict keys are already unique; sort them for deterministic output.
        apis = sorted(providers)

        return StackRunConfig(
            image_name=name,
            docker_image=docker_image,
            conda_env=name,
            apis=apis,
            providers=provider_configs,
            metadata_store=SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=distro_dir,
                db_name="registry.db",
            ),
            models=self.default_models or [],
            shields=self.default_shields or [],
            tool_groups=self.default_tool_groups or [],
        )
|
|
|
|
|
|
class DistributionTemplate(BaseModel):
    """
    Represents a Llama Stack distribution instance that can generate configuration
    and documentation files.
    """

    name: str
    description: str
    distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]

    # API name -> provider types offered for that API.
    providers: Dict[str, List[str]]
    # Output file name (relative to the yaml output dir) -> settings used to
    # render that run config.
    run_configs: Dict[str, RunConfigSettings]
    # Jinja2 template for the markdown docs; docs are only written when set.
    template_path: Optional[Path] = None

    # Optional configuration
    # Passed through to the docs template. NOTE(review): presumably
    # env var name -> (default value, description); confirm against templates.
    run_config_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
    docker_image: Optional[str] = None

    # Only passed to the docs template; run configs take their defaults from
    # RunConfigSettings.default_models instead.
    default_models: Optional[List[ModelInput]] = None

    def build_config(self) -> BuildConfig:
        """Return the ``BuildConfig`` for this distribution.

        The image type is fixed to ``"conda"`` here; callers can override it
        on the returned object.
        """
        return BuildConfig(
            name=self.name,
            distribution_spec=DistributionSpec(
                description=self.description,
                docker_image=self.docker_image,
                providers=self.providers,
            ),
            image_type="conda",  # default to conda, can be overridden
        )

    def generate_markdown_docs(self) -> str:
        """Render the markdown documentation from ``template_path``.

        The template receives the distribution's name, description,
        providers, a pre-built markdown providers table, the run-config env
        vars, and the default models.

        Raises:
            ValueError: If ``template_path`` is not set.
        """
        if self.template_path is None:
            raise ValueError(
                f"Distribution template '{self.name}' has no template_path; "
                "cannot generate markdown docs"
            )

        # Hand-built markdown table: one row per API, sorted by API name.
        rows = ["| API | Provider(s) |", "|-----|-------------|"]
        for api, api_providers in sorted(self.providers.items()):
            providers_str = ", ".join(f"`{p}`" for p in api_providers)
            rows.append(f"| {api} | {providers_str} |")
        providers_table = "\n".join(rows) + "\n"

        template_text = self.template_path.read_text(encoding="utf-8")
        env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
        template = env.from_string(template_text)
        return template.render(
            name=self.name,
            description=self.description,
            providers=self.providers,
            providers_table=providers_table,
            run_config_env_vars=self.run_config_env_vars,
            default_models=self.default_models,
        )

    def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
        """Write build.yaml, every run config, and (if templated) the docs.

        Args:
            yaml_output_dir: Directory receiving ``build.yaml`` and one YAML
                file per entry in ``run_configs``; created if missing.
            doc_output_dir: Directory receiving ``{name}.md``; created if
                missing. The doc is only written when ``template_path`` is set.
        """

        def enum_representer(dumper, data):
            # Serialize the enum by its value as a plain YAML string.
            return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)

        # Register YAML representer for ModelType. Both registrations are
        # process-global; safe_dump below uses the SafeDumper one.
        yaml.add_representer(ModelType, enum_representer)
        yaml.SafeDumper.add_representer(ModelType, enum_representer)

        for output_dir in (yaml_output_dir, doc_output_dir):
            output_dir.mkdir(parents=True, exist_ok=True)

        build_config = self.build_config()
        with open(yaml_output_dir / "build.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(
                build_config.model_dump(exclude_none=True),
                f,
                sort_keys=False,  # keep model field order in the output
            )

        for yaml_pth, settings in self.run_configs.items():
            run_config = settings.run_config(
                self.name, self.providers, self.docker_image
            )
            with open(yaml_output_dir / yaml_pth, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    run_config.model_dump(exclude_none=True),
                    f,
                    sort_keys=False,
                )

        if self.template_path:
            docs = self.generate_markdown_docs()
            with open(doc_output_dir / f"{self.name}.md", "w", encoding="utf-8") as f:
                # Guarantee exactly one trailing newline is present.
                f.write(docs if docs.endswith("\n") else docs + "\n")
|