mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
working tools runtime
This commit is contained in:
parent
744eb0888c
commit
84d01fe8f8
6 changed files with 224 additions and 5 deletions
|
@ -49,6 +49,10 @@ class Tool(Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolStore(Protocol):
    """Structural interface for an object that resolves a tool by identifier."""

    # NOTE(review): this is declared as a plain method, but the meta-reference
    # runtime awaits the value returned by ``tool_store.get_tool`` — confirm
    # whether implementations are expected to be coroutines.
    def get_tool(self, identifier: str) -> Tool:
        """Return the ``Tool`` registered under ``identifier``."""
        ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class Tools(Protocol):
|
class Tools(Protocol):
|
||||||
|
@ -88,6 +92,8 @@ class Tools(Protocol):
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class ToolRuntime(Protocol):
|
class ToolRuntime(Protocol):
|
||||||
|
tool_store: ToolStore
|
||||||
|
|
||||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||||
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
|
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
|
||||||
"""Run a tool with the given arguments"""
|
"""Run a tool with the given arguments"""
|
||||||
|
|
|
@ -108,6 +108,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
await add_objects(scoring_functions, pid, ScoringFn)
|
await add_objects(scoring_functions, pid, ScoringFn)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
p.eval_task_store = self
|
p.eval_task_store = self
|
||||||
|
elif api == Api.tool_runtime:
|
||||||
|
p.tool_store = self
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
|
@ -129,6 +131,8 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
return ("Scoring", "scoring_function")
|
return ("Scoring", "scoring_function")
|
||||||
elif isinstance(self, EvalTasksRoutingTable):
|
elif isinstance(self, EvalTasksRoutingTable):
|
||||||
return ("Eval", "eval_task")
|
return ("Eval", "eval_task")
|
||||||
|
elif isinstance(self, ToolsRoutingTable):
|
||||||
|
return ("Tools", "tool")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,21 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .config import MetaReferenceToolRuntimeConfig
|
from .config import MetaReferenceToolRuntimeConfig
|
||||||
from .meta_reference import MetaReferenceToolRuntimeImpl
|
from .meta_reference import MetaReferenceToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
|
class MetaReferenceProviderDataValidator(BaseModel):
    """Schema for the provider data of this runtime: one required API key."""

    # Required string; no default is declared.
    api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(
    config: MetaReferenceToolRuntimeConfig, _deps: Dict[str, Any]
):
    """Build the meta-reference tool runtime and initialize it.

    Args:
        config: Configuration handed to ``MetaReferenceToolRuntimeImpl``.
        _deps: Accepted but not used by this function.

    Returns:
        The initialized ``MetaReferenceToolRuntimeImpl`` instance.
    """
    runtime = MetaReferenceToolRuntimeImpl(config)
    await runtime.initialize()
    return runtime
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def bing_search(query: str, __api_key__: str, top_k: int = 3, **kwargs) -> str:
    """Query the Bing Web Search v7 endpoint and return trimmed results as JSON.

    Args:
        query: Search text, sent as the ``q`` parameter.
        __api_key__: Subscription key, sent in ``Ocp-Apim-Subscription-Key``.
        top_k: Number of results requested from Bing (the ``count`` parameter).
        **kwargs: Accepted and ignored.

    Returns:
        JSON-encoded string of the dict produced by ``_bing_clean_response``.

    Raises:
        requests.HTTPError: If Bing answers with a non-success status.
        requests.Timeout: If Bing does not answer within the timeout.
    """
    url = "https://api.bing.microsoft.com/v7.0/search"
    headers = {
        "Ocp-Apim-Subscription-Key": __api_key__,
    }
    params = {
        "count": top_k,
        "textDecorations": True,
        "textFormat": "HTML",
        "q": query,
    }

    # Without an explicit timeout, requests waits forever on a stalled
    # connection. NOTE(review): requests is synchronous, so the event loop is
    # blocked for the duration of this call.
    response = requests.get(url=url, params=params, headers=headers, timeout=30)
    response.raise_for_status()
    clean = _bing_clean_response(response.json())
    return json.dumps(clean)
|
||||||
|
|
||||||
|
|
||||||
|
def _bing_clean_response(search_response):
|
||||||
|
clean_response = []
|
||||||
|
query = search_response["queryContext"]["originalQuery"]
|
||||||
|
if "webPages" in search_response:
|
||||||
|
pages = search_response["webPages"]["value"]
|
||||||
|
for p in pages:
|
||||||
|
selected_keys = {"name", "url", "snippet"}
|
||||||
|
clean_response.append({k: v for k, v in p.items() if k in selected_keys})
|
||||||
|
if "news" in search_response:
|
||||||
|
clean_news = []
|
||||||
|
news = search_response["news"]["value"]
|
||||||
|
for n in news:
|
||||||
|
selected_keys = {"name", "url", "description"}
|
||||||
|
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
||||||
|
|
||||||
|
clean_response.append(clean_news)
|
||||||
|
|
||||||
|
return {"query": query, "top_k": clean_response}
|
||||||
|
|
||||||
|
|
||||||
|
async def brave_search(query: str, __api_key__: str) -> str:
    """Query the Brave web search endpoint and return trimmed results as JSON.

    Args:
        query: Search text, sent as the ``q`` parameter.
        __api_key__: Subscription token, sent in ``X-Subscription-Token``.

    Returns:
        JSON-encoded string of the dict produced by ``_clean_brave_response``.

    Raises:
        requests.HTTPError: If Brave answers with a non-success status.
        requests.Timeout: If Brave does not answer within the timeout.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "X-Subscription-Token": __api_key__,
        "Accept-Encoding": "gzip",
        "Accept": "application/json",
    }
    payload = {"q": query}
    # Bounded wait so a stalled connection cannot hang the tool call.
    response = requests.get(url=url, params=payload, headers=headers, timeout=30)
    # Surface HTTP failures (bad key, quota, ...) instead of trying to clean
    # an error body, as bing_search already does.
    response.raise_for_status()
    return json.dumps(_clean_brave_response(response.json()))
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_brave_response(search_response, top_k=3):
|
||||||
|
query = None
|
||||||
|
clean_response = []
|
||||||
|
if "query" in search_response:
|
||||||
|
if "original" in search_response["query"]:
|
||||||
|
query = search_response["query"]["original"]
|
||||||
|
if "mixed" in search_response:
|
||||||
|
mixed_results = search_response["mixed"]
|
||||||
|
for m in mixed_results["main"][:top_k]:
|
||||||
|
r_type = m["type"]
|
||||||
|
results = search_response[r_type]["results"]
|
||||||
|
if r_type == "web":
|
||||||
|
# For web data - add a single output from the search
|
||||||
|
idx = m["index"]
|
||||||
|
selected_keys = [
|
||||||
|
"type",
|
||||||
|
"title",
|
||||||
|
"url",
|
||||||
|
"description",
|
||||||
|
"date",
|
||||||
|
"extra_snippets",
|
||||||
|
]
|
||||||
|
cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}
|
||||||
|
elif r_type == "faq":
|
||||||
|
# For faw data - take a list of all the questions & answers
|
||||||
|
selected_keys = ["type", "question", "answer", "title", "url"]
|
||||||
|
cleaned = []
|
||||||
|
for q in results:
|
||||||
|
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
|
||||||
|
elif r_type == "infobox":
|
||||||
|
idx = m["index"]
|
||||||
|
selected_keys = [
|
||||||
|
"type",
|
||||||
|
"title",
|
||||||
|
"url",
|
||||||
|
"description",
|
||||||
|
"long_desc",
|
||||||
|
]
|
||||||
|
cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}
|
||||||
|
elif r_type == "videos":
|
||||||
|
selected_keys = [
|
||||||
|
"type",
|
||||||
|
"url",
|
||||||
|
"title",
|
||||||
|
"description",
|
||||||
|
"date",
|
||||||
|
]
|
||||||
|
cleaned = []
|
||||||
|
for q in results:
|
||||||
|
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
|
||||||
|
elif r_type == "locations":
|
||||||
|
# For faw data - take a list of all the questions & answers
|
||||||
|
selected_keys = [
|
||||||
|
"type",
|
||||||
|
"title",
|
||||||
|
"url",
|
||||||
|
"description",
|
||||||
|
"coordinates",
|
||||||
|
"postal_address",
|
||||||
|
"contact",
|
||||||
|
"rating",
|
||||||
|
"distance",
|
||||||
|
"zoom_level",
|
||||||
|
]
|
||||||
|
cleaned = []
|
||||||
|
for q in results:
|
||||||
|
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
|
||||||
|
elif r_type == "news":
|
||||||
|
# For faw data - take a list of all the questions & answers
|
||||||
|
selected_keys = [
|
||||||
|
"type",
|
||||||
|
"title",
|
||||||
|
"url",
|
||||||
|
"description",
|
||||||
|
]
|
||||||
|
cleaned = []
|
||||||
|
for q in results:
|
||||||
|
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
|
||||||
|
else:
|
||||||
|
cleaned = []
|
||||||
|
|
||||||
|
clean_response.append(cleaned)
|
||||||
|
|
||||||
|
return {"query": query, "top_k": clean_response}
|
||||||
|
|
||||||
|
|
||||||
|
async def tavily_search(query: str, __api_key__: str) -> str:
    """Query the Tavily search endpoint and return trimmed results as JSON.

    Args:
        query: Search text, sent as ``query`` in the JSON request body.
        __api_key__: Tavily key, sent as ``api_key`` in the JSON request body.

    Returns:
        JSON-encoded string of the dict produced by ``_clean_tavily_response``.

    Raises:
        requests.HTTPError: If Tavily answers with a non-success status.
        requests.Timeout: If Tavily does not answer within the timeout.
    """
    response = requests.post(
        "https://api.tavily.com/search",
        json={"api_key": __api_key__, "query": query},
        # Bounded wait so a stalled connection cannot hang the tool call.
        timeout=30,
    )
    # Surface HTTP failures instead of indexing into an error body.
    response.raise_for_status()
    return json.dumps(_clean_tavily_response(response.json()))
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_tavily_response(search_response, top_k=3):
|
||||||
|
return {"query": search_response["query"], "top_k": search_response["results"]}
|
||||||
|
|
||||||
|
|
||||||
|
async def print_tool(query: str, __api_key__: str) -> str:
    """Log the received query and report success.

    Args:
        query: Text to log.
        __api_key__: Accepted for signature parity with the search tools; it
            is never used or logged.

    Returns:
        The JSON string ``{"result": "success"}``.
    """
    # The key is a credential, so it is deliberately kept out of the log.
    logger.info("print_tool called with query: %s", query)
    return json.dumps({"result": "success"})
|
|
@ -4,15 +4,31 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins
|
||||||
|
|
||||||
from llama_stack.apis.tools import Tool, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolRuntime
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
from .config import MetaReferenceToolRuntimeConfig
|
from .config import MetaReferenceToolRuntimeConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|
||||||
|
class ToolType(Enum):
    """Names of the tools this runtime accepts; each value equals its name."""

    # Web search tools.
    bing_search = "bing_search"
    brave_search = "brave_search"
    tavily_search = "tavily_search"

    # Tool that only logs its input.
    print_tool = "print_tool"
|
||||||
|
|
||||||
|
|
||||||
|
class MetaReferenceToolRuntimeImpl(
|
||||||
|
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(self, config: MetaReferenceToolRuntimeConfig):
    """Store the runtime configuration on the instance."""
    self.config = config
|
||||||
|
|
||||||
|
@ -21,10 +37,27 @@ class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool):
    """Accept a tool only if it maps to one of the supported ``ToolType`` names.

    Args:
        tool: Tool being registered; its ``provider_resource_id`` must be the
            name of a ``ToolType`` member.

    Raises:
        ValueError: If ``tool.provider_resource_id`` is not a supported tool.
    """
    # Goes through the module logger rather than a stray print to stdout.
    logger.info("registering tool %s", tool.identifier)
    if tool.provider_resource_id not in ToolType.__members__:
        raise ValueError(
            f"Tool {tool.identifier} not a supported tool by Meta Reference"
        )
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
    """Reject the request: this runtime cannot unregister tools.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Meta Reference does not support unregistering tools"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
    """Run the builtin tool registered under ``tool_id``.

    The tool is resolved through ``self.tool_store``; its
    ``provider_resource_id`` names the coroutine in the ``builtins`` module
    that is awaited with ``args`` plus the request's API key.

    Args:
        tool_id: Identifier of the tool to run.
        args: Keyword arguments for the tool. A caller-supplied
            ``__api_key__`` is overridden by the key taken from the request's
            provider data. The dict itself is left unmodified.

    Returns:
        Whatever the builtin tool coroutine returns.

    Raises:
        ValueError: If the tool is not one of the supported ``ToolType`` names,
            or if no API key is available for the request.
    """
    tool = await self.tool_store.get_tool(tool_id)

    # Only dispatch to the supported tool functions; an unchecked getattr
    # would also reach helpers and imported modules living in ``builtins``.
    if tool.provider_resource_id not in ToolType.__members__:
        raise ValueError(
            f"Tool {tool.identifier} not a supported tool by Meta Reference"
        )

    if args.get("__api_key__") is not None:
        # The placeholder used to be logged literally because the message was
        # not an f-string; pass the value through logging's lazy formatting.
        logger.warning(
            "__api_key__ is a reserved argument for this tool: %s", tool_id
        )

    # Build a fresh dict so the credential never leaks into the caller's args.
    call_args = {**args, "__api_key__": self._get_api_key()}
    return await getattr(builtins, tool.provider_resource_id)(**call_args)
|
||||||
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.api_key
|
||||||
|
|
|
@ -17,5 +17,6 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.inline.tool_runtime.meta_reference",
|
module="llama_stack.providers.inline.tool_runtime.meta_reference",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceProviderDataValidator",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue