From 297d51b1839a39bde0affed43cdba60cfaeea70a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 30 Aug 2024 12:10:15 -0700 Subject: [PATCH] Support downloading of URLs for attachments for code interpreter --- llama_toolchain/agentic_system/client.py | 1 - .../meta_reference/agent_instance.py | 41 +++++++++++++++++-- llama_toolchain/core/server.py | 2 +- llama_toolchain/inference/client.py | 1 - llama_toolchain/memory/client.py | 1 - llama_toolchain/safety/client.py | 1 - 6 files changed, 38 insertions(+), 9 deletions(-) diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index e73d8b70e..fadb78182 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -32,7 +32,6 @@ def encodable_dict(d: BaseModel): class AgenticSystemClient(AgenticSystem): def __init__(self, base_url: str): - print(f"Agentic System passthrough to -> {base_url}") self.base_url = base_url async def create_agentic_system( diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 1c75f0b83..ed3145b1e 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -6,9 +6,17 @@ import asyncio import copy +import os +import secrets +import shutil +import string +import tempfile import uuid from datetime import datetime from typing import AsyncGenerator, List, Tuple +from urllib.parse import urlparse + +import httpx from termcolor import cprint @@ -26,6 +34,12 @@ from llama_toolchain.tools.builtin import ( from .safety import SafetyException, ShieldRunnerMixin +def make_random_string(length: int = 8): + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) + + class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -44,6 +58,7 @@ class ChatAgent(ShieldRunnerMixin): self.max_infer_iters = max_infer_iters self.tools_dict = {t.get_name(): t for t in builtin_tools} + self.tempdir = tempfile.mkdtemp() self.sessions = {} ShieldRunnerMixin.__init__( @@ -53,6 +68,9 @@ class ChatAgent(ShieldRunnerMixin): output_shields=agent_config.output_shields, ) + def __del__(self): + shutil.rmtree(self.tempdir) + def turn_to_messages(self, turn: Turn) -> List[Message]: messages = [] @@ -343,7 +361,8 @@ class ChatAgent(ShieldRunnerMixin): elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools: urls = [a.content for a in attachments if isinstance(a.content, URL)] - input_messages.append(attachment_message(urls)) + msg = await attachment_message(self.tempdir, urls) + input_messages.append(msg) output_attachments = [] @@ -707,13 +726,27 @@ class ChatAgent(ShieldRunnerMixin): return ret -def attachment_message(urls: List[URL]) -> ToolResponseMessage: +async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage: content = [] for url in urls: uri = url.uri - assert uri.startswith("file://") - filepath = uri[len("file://") :] + if uri.startswith("file://"): + filepath = uri[len("file://") :] + elif uri.startswith("http"): + path = urlparse(uri).path + basename = os.path.basename(path) + filepath = f"{tempdir}/{make_random_string() + basename}" + print(f"Downloading {url} -> {filepath}") + + async with httpx.AsyncClient() as client: + r = await client.get(uri) + resp = r.text + with open(filepath, "w") as fp: + fp.write(resp) + else: + raise ValueError(f"Unsupported URL {url}") + content.append(f'# There is a file accessible to you at "{filepath}"\n') return ToolResponseMessage( diff --git a/llama_toolchain/core/server.py b/llama_toolchain/core/server.py index cf290b951..4de84b726 100644 --- a/llama_toolchain/core/server.py +++ b/llama_toolchain/core/server.py @@ -304,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): and provider_spec.adapter is None ): for endpoint in endpoints: - url = impl.__provider_config__.url + url = impl.__provider_config__.url.rstrip("/") + endpoint.route getattr(app, endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index e90d9c86c..5ba9314bc 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -36,7 +36,6 @@ def encodable_dict(d: BaseModel): class InferenceClient(Inference): def __init__(self, base_url: str): - print(f"Inference passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None: diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py index b2d9ab656..4401276fa 100644 --- a/llama_toolchain/memory/client.py +++ b/llama_toolchain/memory/client.py @@ -22,7 +22,6 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory: class MemoryClient(Memory): def __init__(self, base_url: str): - print(f"Memory passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None: diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 73c5682e3..0cf7deae8 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -31,7 +31,6 @@ def encodable_dict(d: BaseModel): class SafetyClient(Safety): def __init__(self, base_url: str): - print(f"Safety passthrough to -> {base_url}") self.base_url = base_url async def initialize(self) -> None: