mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Support downloading of URLs for attachments for code interpreter
This commit is contained in:
parent
afb18880b5
commit
297d51b183
6 changed files with 38 additions and 9 deletions
|
@ -32,7 +32,6 @@ def encodable_dict(d: BaseModel):
|
|||
|
||||
class AgenticSystemClient(AgenticSystem):
|
||||
def __init__(self, base_url: str):
|
||||
print(f"Agentic System passthrough to -> {base_url}")
|
||||
self.base_url = base_url
|
||||
|
||||
async def create_agentic_system(
|
||||
|
|
|
@ -6,9 +6,17 @@
|
|||
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
|
@ -26,6 +34,12 @@ from llama_toolchain.tools.builtin import (
|
|||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
return "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||
)
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -44,6 +58,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.max_infer_iters = max_infer_iters
|
||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
||||
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
self.sessions = {}
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
|
@ -53,6 +68,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
shutil.rmtree(self.tempdir)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
|
@ -343,7 +361,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
input_messages.append(attachment_message(urls))
|
||||
msg = await attachment_message(self.tempdir, urls)
|
||||
input_messages.append(msg)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
@ -707,13 +726,27 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return ret
|
||||
|
||||
|
||||
def attachment_message(urls: List[URL]) -> ToolResponseMessage:
|
||||
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
|
||||
content = []
|
||||
|
||||
for url in urls:
|
||||
uri = url.uri
|
||||
assert uri.startswith("file://")
|
||||
if uri.startswith("file://"):
|
||||
filepath = uri[len("file://") :]
|
||||
elif uri.startswith("http"):
|
||||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
print(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
resp = r.text
|
||||
with open(filepath, "w") as fp:
|
||||
fp.write(resp)
|
||||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(f'# There is a file accessible to you at "{filepath}"\n')
|
||||
|
||||
return ToolResponseMessage(
|
||||
|
|
|
@ -304,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
|
|
|
@ -36,7 +36,6 @@ def encodable_dict(d: BaseModel):
|
|||
|
||||
class InferenceClient(Inference):
|
||||
def __init__(self, base_url: str):
|
||||
print(f"Inference passthrough to -> {base_url}")
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -22,7 +22,6 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
|
|||
|
||||
class MemoryClient(Memory):
|
||||
def __init__(self, base_url: str):
|
||||
print(f"Memory passthrough to -> {base_url}")
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -31,7 +31,6 @@ def encodable_dict(d: BaseModel):
|
|||
|
||||
class SafetyClient(Safety):
|
||||
def __init__(self, base_url: str):
|
||||
print(f"Safety passthrough to -> {base_url}")
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue