Fireworks agent client

This commit is contained in:
benjibc 2024-10-29 18:34:52 +00:00
parent 8a3b64d1be
commit 21ede64d4b

View file

@ -20,8 +20,9 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .agents import * # noqa: F403
from .event_logger import EventLogger
from urllib.parse import urljoin, urlparse
from .event_logger import EventLogger
load_dotenv()
@ -35,13 +36,21 @@ def encodable_dict(d: BaseModel):
class AgentsClient(Agents):
def __init__(self, base_url: str):
self.base_url = base_url
def __init__(self, base_url: str, port: int):
# Check if base_url is a full URL
parsed_url = urlparse(base_url)
if parsed_url.scheme:
# If it's a full URL, use it as is
self.base_url = base_url.rstrip("/")
else:
# If it's just a hostname, construct the URL with port
self.base_url = f"http://{base_url}:{port}"
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
async with httpx.AsyncClient() as client:
url = urljoin(self.base_url, "/agents/create")
response = await client.post(
f"{self.base_url}/agents/create",
url,
json={
"agent_config": encodable_dict(agent_config),
},
@ -56,8 +65,9 @@ class AgentsClient(Agents):
session_name: str,
) -> AgentSessionCreateResponse:
async with httpx.AsyncClient() as client:
url = urljoin(self.base_url, "/agents/session/create")
response = await client.post(
f"{self.base_url}/agents/session/create",
url,
json={
"agent_id": agent_id,
"session_name": session_name,
@ -144,7 +154,7 @@ async def _run_agent(
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
api = AgentsClient(host, port)
tool_definitions = [
SearchToolDefinition(
@ -184,7 +194,7 @@ async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
api = AgentsClient(host, port)
urls = [
"memory_optimizations.rst",
@ -225,7 +235,7 @@ async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Inst
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
api = AgentsClient(host, port)
# zero shot tools for llama3.2 text models
tool_definitions = [
@ -271,12 +281,24 @@ async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
"""
Main function that handles both hostname and full URL cases while keeping port required
@param host: Hostname or full URL
@param port: Port number (required)
@param run_type: Type of run (tools_llama_3_1, tools_llama_3_2, rag_llama_3_2)
@param model: Optional model name
"""
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
"rag_llama_3_2",
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
# Extract hostname if a full URL is provided
if host.startswith(("http://", "https://")):
# Keep the scheme (http:// or https://) but remove any trailing slashes
host = host.rstrip("/")
fn = {
"tools_llama_3_1": run_llama_3_1,
"tools_llama_3_2": run_llama_3_2,