From 21ede64d4b70cad4b1b1aef9e076efc5e1eca8e2 Mon Sep 17 00:00:00 2001
From: benjibc
Date: Tue, 29 Oct 2024 18:34:52 +0000
Subject: [PATCH] Fireworks agent client

---
 llama_stack/apis/agents/client.py | 38 ++++++++++++++++++++++++-------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py
index b45447328..61b1faac7 100644
--- a/llama_stack/apis/agents/client.py
+++ b/llama_stack/apis/agents/client.py
@@ -20,8 +20,9 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 from .agents import *  # noqa: F403
-from .event_logger import EventLogger
+from urllib.parse import urljoin, urlparse
 
+from .event_logger import EventLogger
 load_dotenv()
 
 
@@ -35,13 +36,21 @@ def encodable_dict(d: BaseModel):
 
 
 class AgentsClient(Agents):
-    def __init__(self, base_url: str):
-        self.base_url = base_url
+    def __init__(self, base_url: str, port: int):
+        # Check if base_url is a full URL
+        parsed_url = urlparse(base_url)
+        if parsed_url.scheme:
+            # If it's a full URL, use it as is
+            self.base_url = base_url.rstrip("/")
+        else:
+            # If it's just a hostname, construct the URL with port
+            self.base_url = f"http://{base_url}:{port}"
 
     async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
         async with httpx.AsyncClient() as client:
+            url = urljoin(self.base_url, "/agents/create")
             response = await client.post(
-                f"{self.base_url}/agents/create",
+                url,
                 json={
                     "agent_config": encodable_dict(agent_config),
                 },
@@ -56,8 +65,9 @@ class AgentsClient(Agents):
         session_name: str,
     ) -> AgentSessionCreateResponse:
         async with httpx.AsyncClient() as client:
+            url = urljoin(self.base_url, "/agents/session/create")
             response = await client.post(
-                f"{self.base_url}/agents/session/create",
+                url,
                 json={
                     "agent_id": agent_id,
                     "session_name": session_name,
@@ -144,7 +154,7 @@ async def _run_agent(
 
 
 async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
-    api = AgentsClient(f"http://{host}:{port}")
+    api = AgentsClient(host, port)
 
     tool_definitions = [
         SearchToolDefinition(
@@ -184,7 +194,7 @@ async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct
 
 
 async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
-    api = AgentsClient(f"http://{host}:{port}")
+    api = AgentsClient(host, port)
 
     urls = [
         "memory_optimizations.rst",
@@ -225,7 +235,7 @@ async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Inst
 
 
 async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
-    api = AgentsClient(f"http://{host}:{port}")
+    api = AgentsClient(host, port)
 
     # zero shot tools for llama3.2 text models
     tool_definitions = [
@@ -271,12 +281,24 @@ async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct
 
 
 def main(host: str, port: int, run_type: str, model: Optional[str] = None):
+    """
+    Main function that handles both hostname and full URL cases while keeping port required
+    @param host: Hostname or full URL
+    @param port: Port number (required)
+    @param run_type: Type of run (tools_llama_3_1, tools_llama_3_2, rag_llama_3_2)
+    @param model: Optional model name
+    """
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",
         "rag_llama_3_2",
    ], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
 
+    # Extract hostname if a full URL is provided
+    if host.startswith(("http://", "https://")):
+        # Keep the scheme (http:// or https://) but remove any trailing slashes
+        host = host.rstrip("/")
+
     fn = {
         "tools_llama_3_1": run_llama_3_1,
         "tools_llama_3_2": run_llama_3_2,
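
Usage sketch (reviewer note, not part of the patch): with this change AgentsClient
accepts either a bare hostname plus a required port, or a full URL whose scheme is
kept and whose trailing slash is stripped; endpoint paths such as "/agents/create"
are then resolved with urljoin, so they always attach at the root of the host. A
minimal illustration, assuming the module path shown in this diff and purely
hypothetical host/port values:

    import asyncio

    from llama_stack.apis.agents.client import AgentsClient


    async def demo() -> None:
        # Hostname + port: the client builds "http://localhost:5000" itself.
        local_client = AgentsClient("localhost", 5000)
        print(local_client.base_url)  # -> http://localhost:5000

        # Full URL: used as-is minus the trailing slash; the port argument is
        # still required but only matters in the hostname branch.
        remote_client = AgentsClient("https://agents.example.com/", 443)
        print(remote_client.base_url)  # -> https://agents.example.com

        # Methods like create_agent() then POST to
        # urljoin(base_url, "/agents/create").


    asyncio.run(demo())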