agents to use tools api

This commit is contained in:
Dinesh Yeduguru 2024-12-20 14:46:32 -08:00
parent 596afc6497
commit f90e9c2003
21 changed files with 538 additions and 329 deletions

View file

@ -161,6 +161,7 @@ a default SQLite store will be used.""",
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
class BuildConfig(BaseModel):

View file

@ -5,9 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
@ -38,7 +35,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,

View file

@ -523,6 +523,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
)
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict
tool_group = parse_obj_as(ToolGroupDef, tool_group)
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group

View file

@ -12,7 +12,7 @@ from typing import Any, Dict, Optional
import pkg_resources
import yaml
from llama_models.llama3.api.datatypes import * # noqa: F403
from termcolor import colored
from llama_stack.apis.agents import Agents
@ -33,14 +33,12 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -81,6 +79,7 @@ RESOURCES = [
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]