Merge remote-tracking branch 'origin/api_keys' into api_updates_2

This commit is contained in:
Ashwin Bharambe 2024-09-17 14:31:22 -07:00
commit ed351b8276
5 changed files with 16 additions and 23 deletions

View file

@ -51,6 +51,7 @@ class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses # NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name # brave_search as tool call name
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
api_key: str
engine: SearchEngineType = SearchEngineType.brave engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None
@ -58,6 +59,7 @@ class SearchToolDefinition(ToolDefinitionCommon):
@json_schema_type @json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon): class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
api_key: str
remote_execution: Optional[RestAPIExecutionConfig] = None remote_execution: Optional[RestAPIExecutionConfig] = None

View file

@ -6,11 +6,12 @@
import asyncio import asyncio
import json import json
import os
from typing import AsyncGenerator from typing import AsyncGenerator
import fire import fire
import httpx import httpx
from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
@ -22,6 +23,9 @@ from .agents import * # noqa: F403
from .event_logger import EventLogger from .event_logger import EventLogger
load_dotenv()
async def get_client_impl(config: RemoteProviderConfig, _deps): async def get_client_impl(config: RemoteProviderConfig, _deps):
return AgentsClient(config.url) return AgentsClient(config.url)
@ -129,8 +133,11 @@ async def run_main(host: str, port: int):
api = AgentsClient(f"http://{host}:{port}") api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [ tool_definitions = [
SearchToolDefinition(engine=SearchEngineType.bing), SearchToolDefinition(
WolframAlphaToolDefinition(), engine=SearchEngineType.bing,
api_key=os.getenv("BING_SEARCH_API_KEY"),
),
WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")),
CodeInterpreterToolDefinition(), CodeInterpreterToolDefinition(),
] ]
tool_definitions += [ tool_definitions += [

View file

@ -58,21 +58,9 @@ class MetaReferenceAgentsImpl(Agents):
builtin_tools = [] builtin_tools = []
for tool_defn in agent_config.tools: for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition): if isinstance(tool_defn, WolframAlphaToolDefinition):
key = self.config.wolfram_api_key tool = WolframAlphaTool(tool_defn.api_key)
if not key:
raise ValueError("Wolfram API key not defined in config")
tool = WolframAlphaTool(key)
elif isinstance(tool_defn, SearchToolDefinition): elif isinstance(tool_defn, SearchToolDefinition):
key = None tool = SearchTool(tool_defn.engine, tool_defn.api_key)
if tool_defn.engine == SearchEngineType.brave:
key = self.config.brave_search_api_key
elif tool_defn.engine == SearchEngineType.bing:
key = self.config.bing_search_api_key
if not key:
raise ValueError(
"Search (Brave or Bing) API key not defined in config"
)
tool = SearchTool(tool_defn.engine, key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition): elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool() tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition): elif isinstance(tool_defn, PhotogenToolDefinition):

View file

@ -4,12 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
class MetaReferenceImplConfig(BaseModel): class MetaReferenceImplConfig(BaseModel): ...
brave_search_api_key: Optional[str] = None
bing_search_api_key: Optional[str] = None
wolfram_api_key: Optional[str] = None

View file

@ -3,6 +3,7 @@ fire
httpx httpx
huggingface-hub huggingface-hub
llama-models>=0.0.16 llama-models>=0.0.16
python-dotenv
pydantic pydantic
requests requests
termcolor termcolor