chore: convert built-in tools' blocking calls into async non-blocking calls

sarthakdeshpande 2025-03-09 14:09:06 +05:30
parent ba917a9c48
commit 1920c65f61
4 changed files with 34 additions and 25 deletions
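
Each file applies the same change: a blocking requests call inside an async method is replaced with an awaited httpx request, so the event loop is no longer stalled while the HTTP round trip is in flight. A minimal sketch of the before/after shape (the fetch_json name and arguments are illustrative, not taken from the diff):

import httpx


# Before (blocking): requests.get() holds the event loop for the whole request.
#     response = requests.get(url, params=params, headers=headers)

# After (non-blocking): the coroutine suspends at `await`, letting other tasks run.
async def fetch_json(url: str, params: dict, headers: dict) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.get(url, params=params, headers=headers)
    response.raise_for_status()  # surface HTTP errors, as the patched tools do
    return response.json()       # body is already read, so this works after the client closes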


@@ -7,7 +7,7 @@
 import json
 from typing import Any, Dict, List, Optional

-import requests
+import httpx

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool):
+    async def register_tool(self, tool: Tool) -> None:
         pass

     async def unregister_tool(self, tool_id: str) -> None:
@@ -77,7 +77,8 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
             "q": kwargs["query"],
         }

-        response = requests.get(
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
             url=self.url,
             params=params,
             headers=headers,


@@ -6,7 +6,7 @@
 from typing import Any, Dict, List, Optional

-import requests
+import httpx

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool):
+    async def register_tool(self, tool: Tool) -> None:
         pass

     async def unregister_tool(self, tool_id: str) -> None:
@@ -74,7 +74,12 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
             "Accept": "application/json",
         }
         payload = {"q": kwargs["query"]}
-        response = requests.get(url=url, params=payload, headers=headers)
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                url=url,
+                params=payload,
+                headers=headers,
+            )
         response.raise_for_status()
         results = self._clean_brave_response(response.json())
         content_items = "\n".join([str(result) for result in results])
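
One design note on the pattern above: each invocation opens and closes its own httpx.AsyncClient, i.e. its own connection pool. A possible variant, not part of this commit, would keep a single client for the provider's lifetime; a rough sketch under that assumption (the class and the shutdown hook are illustrative, not from the repository):

import httpx


class SearchRuntimeBase:
    """Illustrative stand-in for a tool runtime provider, not the real class."""

    async def initialize(self) -> None:
        # One shared client -> one connection pool reused across tool calls.
        self._client = httpx.AsyncClient()

    async def shutdown(self) -> None:
        # Assumed lifecycle hook; close the pool when the provider goes away.
        await self._client.aclose()

    async def _get(self, url: str, **kwargs) -> httpx.Response:
        response = await self._client.get(url, **kwargs)
        response.raise_for_status()
        return response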


@@ -7,7 +7,7 @@
 import json
 from typing import Any, Dict, List, Optional

-import requests
+import httpx

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool):
+    async def register_tool(self, tool: Tool) -> None:
         pass

     async def unregister_tool(self, tool_id: str) -> None:
@@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
         api_key = self._get_api_key()
-        response = requests.post(
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
             "https://api.tavily.com/search",
             json={"api_key": api_key, "query": kwargs["query"]},
         )
+        response.raise_for_status()
         return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))


@@ -7,7 +7,7 @@
 import json
 from typing import Any, Dict, List, Optional

-import requests
+import httpx

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool):
+    async def register_tool(self, tool: Tool) -> None:
         pass

     async def unregister_tool(self, tool_id: str) -> None:
@@ -73,11 +73,12 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
             "format": "plaintext",
             "output": "json",
         }
-        response = requests.get(
-            self.url,
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
             params=params,
+            url=self.url
         )
+        response.raise_for_status()
         return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))

     def _clean_wolfram_alpha_response(self, wa_response):
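
Because invoke_tool now awaits its HTTP call instead of blocking, several tool invocations can overlap on one event loop. A hypothetical usage sketch (the runtime variables and the "web_search" tool name are placeholders, not taken from this commit):

import asyncio


async def run_searches(bing_runtime, brave_runtime, tavily_runtime):
    # With the old blocking requests calls these would have executed one after
    # another; with awaited httpx calls the three requests proceed concurrently.
    return await asyncio.gather(
        bing_runtime.invoke_tool("web_search", {"query": "llama stack"}),
        brave_runtime.invoke_tool("web_search", {"query": "llama stack"}),
        tavily_runtime.invoke_tool("web_search", {"query": "llama stack"}),
    )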