re-work tool definitions, fix FastAPI issues, fix tool regressions

Ashwin Bharambe 2024-08-24 22:07:06 -07:00
parent 8d14d4228b
commit 8efe614719
11 changed files with 144 additions and 104 deletions

@@ -6,49 +6,42 @@
 import asyncio
 import json
 from typing import AsyncGenerator
 
 import fire
 import httpx
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    SamplingParams,
-    ToolParamDefinition,
-    ToolPromptFormat,
-    UserMessage,
-)
+from pydantic import BaseModel
 from termcolor import cprint
 
-from llama_toolchain.agentic_system.event_logger import EventLogger
-from .api import (
-    AgentConfig,
-    AgenticSystem,
-    AgenticSystemCreateResponse,
-    AgenticSystemSessionCreateResponse,
-    AgenticSystemToolDefinition,
-    AgenticSystemTurnCreateRequest,
-    AgenticSystemTurnResponseStreamChunk,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from .api import *  # noqa: F403
+from .event_logger import EventLogger
 
 
 async def get_client_impl(base_url: str):
     return AgenticSystemClient(base_url)
 
 
+def encodable_dict(d: BaseModel):
+    return json.loads(d.json())
+
+
 class AgenticSystemClient(AgenticSystem):
     def __init__(self, base_url: str):
         self.base_url = base_url
 
     async def create_agentic_system(
-        self, request: AgenticSystemCreateRequest
+        self, agent_config: AgentConfig
     ) -> AgenticSystemCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/create",
-                data=request.json(),
+                json={
+                    "agent_config": encodable_dict(agent_config),
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
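The new encodable_dict helper and the switch from data=request.json() to httpx's json= argument go together: httpx needs a plain, JSON-serializable dict, and Pydantic's .json() encoder converts enum and datetime fields that .dict() would leave as Python objects. A toy sketch of that difference (the model and field names below are made up for illustration; the real config types live in llama_models and llama_toolchain):

# Toy illustration of why json.loads(model.json()) is used to build the
# httpx `json=` payload rather than model.dict().
import json
from enum import Enum

from pydantic import BaseModel


class PromptFormat(Enum):  # stand-in for an enum-valued config field
    json = "json"
    function_tag = "function_tag"


class ToyConfig(BaseModel):
    model: str
    prompt_format: PromptFormat


cfg = ToyConfig(model="Meta-Llama3.1-8B-Instruct", prompt_format=PromptFormat.function_tag)
print(cfg.dict())              # prompt_format stays a PromptFormat member; json.dumps() would reject it
print(json.loads(cfg.json()))  # {'model': ..., 'prompt_format': 'function_tag'} -- safe for the json= payload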
@@ -56,12 +49,16 @@ class AgenticSystemClient(AgenticSystem):
 
     async def create_agentic_system_session(
         self,
-        request: AgenticSystemSessionCreateRequest,
+        agent_id: str,
+        session_name: str,
     ) -> AgenticSystemSessionCreateResponse:
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/agentic_system/session/create",
-                data=request.json(),
+                json={
+                    "agent_id": agent_id,
+                    "session_name": session_name,
+                },
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
@@ -75,7 +72,9 @@ class AgenticSystemClient(AgenticSystem):
             async with client.stream(
                 "POST",
                 f"{self.base_url}/agentic_system/turn/create",
-                data=request.json(),
+                json={
+                    "request": encodable_dict(request),
+                },
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             ) as response:
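All three endpoints are now called with a JSON object keyed by parameter name, which is how FastAPI parses request bodies when a handler declares multiple body parameters (or a single embedded one). A hypothetical server-side sketch of the signatures these payloads would match; the routes mirror the URLs above, but the real handlers live elsewhere in llama_toolchain and may differ:

# Hypothetical endpoint shapes matching the client payloads above; not the
# actual llama_toolchain server code.
from fastapi import Body, FastAPI

from llama_toolchain.agentic_system.api import *  # noqa: F403

app = FastAPI()


@app.post("/agentic_system/create")
async def create_agentic_system(
    agent_config: AgentConfig = Body(..., embed=True),  # parsed from {"agent_config": {...}}
) -> AgenticSystemCreateResponse:
    ...


@app.post("/agentic_system/session/create")
async def create_agentic_system_session(
    agent_id: str = Body(...),      # two plain body params are always wrapped,
    session_name: str = Body(...),  # i.e. {"agent_id": ..., "session_name": ...}
) -> AgenticSystemSessionCreateResponse:
    ...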
@@ -96,19 +95,13 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.brave_search,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.wolfram_alpha,
-        ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.code_interpreter,
-        ),
+        BraveSearchToolDefinition(),
+        WolframAlphaToolDefinition(),
+        CodeInterpreterToolDefinition(),
     ]
     tool_definitions += [
-        AgenticSystemToolDefinition(
-            tool_name="get_boiling_point",
+        FunctionCallToolDefinition(
+            function_name="get_boiling_point",
             description="Get the boiling point of an imaginary liquid (e.g. polyjuice)",
             parameters={
                 "liquid_name": ToolParamDefinition(
@@ -128,12 +121,10 @@ async def run_main(host: str, port: int):
     agent_config = AgentConfig(
         model="Meta-Llama3.1-8B-Instruct",
         instructions="You are a helpful assistant",
-        sampling_params=SamplingParams(),
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
         tools=tool_definitions,
         input_shields=[],
         output_shields=[],
-        debug_prefix_messages=[],
-        tool_prompt_format=ToolPromptFormat.json,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
     )
 
+    create_response = await api.create_agentic_system(agent_config)
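The rest of run_main is cut off above. A condensed sketch of how the reworked client could be driven end to end; the response/request field names and the turn method name are assumptions where the diff does not show them, and error handling is omitted:

# Condensed usage sketch of the reworked client API; names marked "assumed"
# are not confirmed by the diff above.
async def demo(api: AgenticSystemClient, agent_config: AgentConfig) -> None:
    create_response = await api.create_agentic_system(agent_config)
    session_response = await api.create_agentic_system_session(
        agent_id=create_response.agent_id,  # field name assumed
        session_name="test-session",
    )

    request = AgenticSystemTurnCreateRequest(  # field names assumed
        agent_id=create_response.agent_id,
        session_id=session_response.session_id,
        messages=[UserMessage(content="What is the boiling point of polyjuice?")],
        stream=True,
    )
    async for chunk in api.create_agentic_system_turn(request):  # method name assumed
        cprint(f"{chunk}", "green")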