Tools API with brave and MCP providers (#639)

This PR adds a new Tools api and adds two tool runtime providers: brave
and MCP.

Test plan:
```
curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \
-H 'Content-Type: application/json' \
-d '{ "tool_group_id": "simple_tool",
  "tool_group": {
    "type": "model_context_protocol",
    "endpoint": {"uri": "http://localhost:56000/sse"}
  },
  "provider_id": "model-context-protocol"
}'

 curl -X POST 'http://localhost:5000/alpha/toolgroups/register' \
-H 'Content-Type: application/json' \
-d '{
  "tool_group_id": "search", "provider_id": "brave-search",
  "tool_group": {
    "type": "user_defined",
    "tools": [
      {
        "name": "brave_search",
        "description": "A web search tool",
        "parameters": [
          {
            "name": "query",
            "parameter_type": "string",
            "description": "The query to search"
          }
        ],
        "metadata": {},
        "tool_prompt_format": "json"
      }
    ]
  }
}'

 curl -X GET http://localhost:5000/alpha/tools/list | jq .
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   662  100   662    0     0   333k      0 --:--:-- --:--:-- --:--:--  646k
[
  {
    "identifier": "brave_search",
    "provider_resource_id": "brave_search",
    "provider_id": "brave-search",
    "type": "tool",
    "tool_group": "search",
    "description": "A web search tool",
    "parameters": [
      {
        "name": "query",
        "parameter_type": "string",
        "description": "The query to search"
      }
    ],
    "metadata": {},
    "tool_prompt_format": "json"
  },
  {
    "identifier": "fetch",
    "provider_resource_id": "fetch",
    "provider_id": "model-context-protocol",
    "type": "tool",
    "tool_group": "simple_tool",
    "description": "Fetches a website and returns its content",
    "parameters": [
      {
        "name": "url",
        "parameter_type": "string",
        "description": "URL to fetch"
      }
    ],
    "metadata": {
      "endpoint": "http://localhost:56000/sse"
    },
    "tool_prompt_format": "json"
  }
]

curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' \
-d '{
    "tool_name": "fetch",
    "args": {
        "url": "http://google.com/"
    }
}'

 curl -X POST 'http://localhost:5000/alpha/tool-runtime/invoke' \
-H 'Content-Type: application/json' -H 'X-LlamaStack-ProviderData: {"api_key": "<KEY>"}' \
-d '{
    "tool_name": "brave_search",
    "args": {
        "query": "who is meta ceo"
    }
}'
```
This commit is contained in:
Dinesh Yeduguru 2024-12-19 21:25:17 -08:00 committed by GitHub
parent 17fdb47e5e
commit c8be0bf1c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 633 additions and 24 deletions

View file

@ -8,19 +8,20 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
@ -37,6 +38,8 @@ RoutableObject = Union[
Dataset,
ScoringFn,
EvalTask,
Tool,
ToolGroup,
]
@ -48,6 +51,8 @@ RoutableObjectWithProvider = Annotated[
Dataset,
ScoringFn,
EvalTask,
Tool,
ToolGroup,
],
Field(discriminator="type"),
]
@ -59,6 +64,7 @@ RoutedProtocol = Union[
DatasetIO,
Scoring,
Eval,
ToolRuntime,
]

View file

@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.eval_tasks,
router_api=Api.eval,
),
AutoRoutedApiInfo(
routing_table_api=Api.tool_groups,
router_api=Api.tool_runtime,
),
]

View file

@ -30,6 +30,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
@ -60,12 +61,15 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),

View file

@ -7,7 +7,6 @@
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from .routing_tables import (
@ -17,6 +16,7 @@ from .routing_tables import (
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
)
@ -33,6 +33,7 @@ async def get_routing_table_impl(
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}
if api.value not in api_to_tables:
@ -51,6 +52,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
MemoryRouter,
SafetyRouter,
ScoringRouter,
ToolRuntimeRouter,
)
api_to_routers = {
@ -60,6 +62,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -6,15 +6,16 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks.memory_banks import BankParams
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.tools import * # noqa: F403
from llama_stack.distribution.datatypes import RoutingTable
class MemoryRouter(Memory):
@ -372,3 +373,28 @@ class EvalRouter(Eval):
task_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)

View file

@ -6,21 +6,19 @@
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from pydantic import parse_obj_as
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.tools import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
def get_impl_api(p: Any) -> Api:
@ -45,6 +43,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_eval_task(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -104,6 +106,8 @@ class CommonRoutingTableImpl(RoutingTable):
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -125,6 +129,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
@ -461,3 +467,88 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
provider_resource_id=provider_eval_task_id,
)
await self.register_object(eval_task)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
tools = []
tool_defs = []
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
)
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
tool_group = await self.get_tool_group(tool_group_id)
if tool_group is None:
raise ValueError(f"Tool group {tool_group_id} not found")
tools = await self.list_tools(tool_group_id)
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)