mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:02:26 +00:00
agents to use tools api
This commit is contained in:
parent
596afc6497
commit
f90e9c2003
21 changed files with 538 additions and 329 deletions
|
|
@ -161,6 +161,7 @@ a default SQLite store will be used.""",
|
|||
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -5,9 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
|
@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
|
|||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AutoRoutedProviderSpec,
|
||||
Provider,
|
||||
|
|
@ -38,7 +35,7 @@ from llama_stack.distribution.datatypes import (
|
|||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
DatasetsProtocolPrivate,
|
||||
|
|
|
|||
|
|
@ -523,6 +523,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
)
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
# parse tool group to the type if dict
|
||||
tool_group = parse_obj_as(ToolGroupDef, tool_group)
|
||||
if isinstance(tool_group, MCPToolGroupDef):
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Any, Dict, Optional
|
|||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
|
@ -33,14 +33,12 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LLAMA_STACK_API_VERSION = "alpha"
|
||||
|
|
@ -81,6 +79,7 @@ RESOURCES = [
|
|||
"list_scoring_functions",
|
||||
),
|
||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue