chore: convert built-in tools' blocking calls into non-blocking async calls (#1509)

# What does this PR do?
This PR converts the blocking calls made by built-in tools such as Wolfram Alpha,
Brave, Tavily, and Bing into non-blocking async calls.
[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]
```
pytest -s -v tool_runtime/test_builtin_tools.py --stack-config=together --text-model=meta-llama/Llama-3.1-8B-Instruct
```
Running the command above produced the results below:
<img width="1710" alt="image"
src="https://github.com/user-attachments/assets/76b0ca06-f6e4-45fa-a114-0449bef2325b"
/>


<img width="1389" alt="image"
src="https://github.com/user-attachments/assets/5220ccbb-7882-4240-b17e-f362ad46d25b"
/>

<img width="1432" alt="image"
src="https://github.com/user-attachments/assets/bb93a41e-e82a-4c98-a22d-6b0e320aa974"
/>

[//]: # (## Documentation)

---------

Co-authored-by: sarthakdeshpande <sarthak.deshpande@engati.com>
This commit is contained in:
Sarthak Deshpande 2025-03-10 05:29:24 +05:30 committed by GitHub
parent 70ff226b6a
commit a9c5d3cd3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 45 additions and 33 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import base64 import base64
import io import io
import json import json
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
await self._save_index() await self._save_index()
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k) distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = [] chunks = []
scores = [] scores = []

View file

@ -7,7 +7,7 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
@ -77,12 +77,13 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
"q": kwargs["query"], "q": kwargs["query"],
} }
response = requests.get( async with httpx.AsyncClient() as client:
url=self.url, response = await client.get(
params=params, url=self.url,
headers=headers, params=params,
) headers=headers,
response.raise_for_status() )
response.raise_for_status()
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json()))) return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
@ -74,8 +74,13 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
"Accept": "application/json", "Accept": "application/json",
} }
payload = {"q": kwargs["query"]} payload = {"q": kwargs["query"]}
response = requests.get(url=url, params=payload, headers=headers) async with httpx.AsyncClient() as client:
response.raise_for_status() response = await client.get(
url=url,
params=payload,
headers=headers,
)
response.raise_for_status()
results = self._clean_brave_response(response.json()) results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results]) content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult( return ToolInvocationResult(

View file

@ -7,7 +7,7 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key() api_key = self._get_api_key()
response = requests.post( async with httpx.AsyncClient() as client:
"https://api.tavily.com/search", response = await client.post(
json={"api_key": api_key, "query": kwargs["query"]}, "https://api.tavily.com/search",
) json={"api_key": api_key, "query": kwargs["query"]},
)
response.raise_for_status()
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json()))) return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))

View file

@ -7,7 +7,7 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
@ -73,11 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
"format": "plaintext", "format": "plaintext",
"output": "json", "output": "json",
} }
response = requests.get( async with httpx.AsyncClient() as client:
self.url, response = await client.get(params=params, url=self.url)
params=params, response.raise_for_status()
)
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json()))) return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
def _clean_wolfram_alpha_response(self, wa_response): def _clean_wolfram_alpha_response(self, wa_response):

View file

@ -8,9 +8,11 @@ import logging
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from pymongo import MongoClient from pymongo import AsyncMongoClient
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -30,7 +32,7 @@ class MongoDBKVStoreImpl(KVStore):
"password": self.config.password, "password": self.config.password,
} }
conn_creds = {k: v for k, v in conn_creds.items() if v is not None} conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
self.conn = MongoClient(**conn_creds) self.conn = AsyncMongoClient(**conn_creds)
self.collection = self.conn[self.config.db][self.config.collection_name] self.collection = self.conn[self.config.db][self.config.collection_name]
except Exception as e: except Exception as e:
log.exception("Could not connect to MongoDB database server") log.exception("Could not connect to MongoDB database server")
@ -44,17 +46,17 @@ class MongoDBKVStoreImpl(KVStore):
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
key = self._namespaced_key(key) key = self._namespaced_key(key)
update_query = {"$set": {"value": value, "expiration": expiration}} update_query = {"$set": {"value": value, "expiration": expiration}}
self.collection.update_one({"key": key}, update_query, upsert=True) await self.collection.update_one({"key": key}, update_query, upsert=True)
async def get(self, key: str) -> Optional[str]: async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key) key = self._namespaced_key(key)
query = {"key": key} query = {"key": key}
result = self.collection.find_one(query, {"value": 1, "_id": 0}) result = await self.collection.find_one(query, {"value": 1, "_id": 0})
return result["value"] if result else None return result["value"] if result else None
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:
key = self._namespaced_key(key) key = self._namespaced_key(key)
self.collection.delete_one({"key": key}) await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> List[str]: async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
@ -63,4 +65,7 @@ class MongoDBKVStoreImpl(KVStore):
"key": {"$gte": start_key, "$lt": end_key}, "key": {"$gte": start_key, "$lt": end_key},
} }
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1) cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
return [doc["value"] for doc in cursor] result = []
async for doc in cursor:
result.append(doc["value"])
return result