Support downloading of URLs for attachments for code interpreter

This commit is contained in:
Ashwin Bharambe 2024-08-30 12:10:15 -07:00
parent afb18880b5
commit 297d51b183
6 changed files with 38 additions and 9 deletions

View file

@ -32,7 +32,6 @@ def encodable_dict(d: BaseModel):
class AgenticSystemClient(AgenticSystem):
def __init__(self, base_url: str):
print(f"Agentic System passthrough to -> {base_url}")
self.base_url = base_url
async def create_agentic_system(

View file

@ -6,9 +6,17 @@
import asyncio
import copy
import os
import secrets
import shutil
import string
import tempfile
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from urllib.parse import urlparse
import httpx
from termcolor import cprint
@ -26,6 +34,12 @@ from llama_toolchain.tools.builtin import (
from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
@ -44,6 +58,7 @@ class ChatAgent(ShieldRunnerMixin):
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
ShieldRunnerMixin.__init__(
@ -53,6 +68,9 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = []
@ -343,7 +361,8 @@ class ChatAgent(ShieldRunnerMixin):
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
input_messages.append(attachment_message(urls))
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
output_attachments = []
@ -707,13 +726,27 @@ class ChatAgent(ShieldRunnerMixin):
return ret
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
for url in urls:
uri = url.uri
assert uri.startswith("file://")
filepath = uri[len("file://") :]
if uri.startswith("file://"):
filepath = uri[len("file://") :]
elif uri.startswith("http"):
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
print(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
with open(filepath, "w") as fp:
fp.write(resp)
else:
raise ValueError(f"Unsupported URL {url}")
content.append(f'# There is a file accessible to you at "{filepath}"\n')
return ToolResponseMessage(

View file

@ -304,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.__provider_config__.url
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)

View file

@ -36,7 +36,6 @@ def encodable_dict(d: BaseModel):
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Inference passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None:

View file

@ -22,7 +22,6 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
class MemoryClient(Memory):
def __init__(self, base_url: str):
print(f"Memory passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None:

View file

@ -31,7 +31,6 @@ def encodable_dict(d: BaseModel):
class SafetyClient(Safety):
def __init__(self, base_url: str):
print(f"Safety passthrough to -> {base_url}")
self.base_url = base_url
async def initialize(self) -> None: