mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Fireworks agent client
This commit is contained in:
parent
8a3b64d1be
commit
21ede64d4b
1 changed file with 30 additions and 8 deletions
|
@ -20,8 +20,9 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from .agents import * # noqa: F403
|
||||
from .event_logger import EventLogger
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from .event_logger import EventLogger
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -35,13 +36,21 @@ def encodable_dict(d: BaseModel):
|
|||
|
||||
|
||||
class AgentsClient(Agents):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
def __init__(self, base_url: str, port: int):
|
||||
# Check if base_url is a full URL
|
||||
parsed_url = urlparse(base_url)
|
||||
if parsed_url.scheme:
|
||||
# If it's a full URL, use it as is
|
||||
self.base_url = base_url.rstrip("/")
|
||||
else:
|
||||
# If it's just a hostname, construct the URL with port
|
||||
self.base_url = f"http://{base_url}:{port}"
|
||||
|
||||
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = urljoin(self.base_url, "/agents/create")
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agents/create",
|
||||
url,
|
||||
json={
|
||||
"agent_config": encodable_dict(agent_config),
|
||||
},
|
||||
|
@ -56,8 +65,9 @@ class AgentsClient(Agents):
|
|||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = urljoin(self.base_url, "/agents/session/create")
|
||||
response = await client.post(
|
||||
f"{self.base_url}/agents/session/create",
|
||||
url,
|
||||
json={
|
||||
"agent_id": agent_id,
|
||||
"session_name": session_name,
|
||||
|
@ -144,7 +154,7 @@ async def _run_agent(
|
|||
|
||||
|
||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
api = AgentsClient(host, port)
|
||||
|
||||
tool_definitions = [
|
||||
SearchToolDefinition(
|
||||
|
@ -184,7 +194,7 @@ async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct
|
|||
|
||||
|
||||
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
api = AgentsClient(host, port)
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
@ -225,7 +235,7 @@ async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Inst
|
|||
|
||||
|
||||
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
api = AgentsClient(host, port)
|
||||
|
||||
# zero shot tools for llama3.2 text models
|
||||
tool_definitions = [
|
||||
|
@ -271,12 +281,24 @@ async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct
|
|||
|
||||
|
||||
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
||||
"""
|
||||
Main function that handles both hostname and full URL cases while keeping port required
|
||||
@param host: Hostname or full URL
|
||||
@param port: Port number (required)
|
||||
@param run_type: Type of run (tools_llama_3_1, tools_llama_3_2, rag_llama_3_2)
|
||||
@param model: Optional model name
|
||||
"""
|
||||
assert run_type in [
|
||||
"tools_llama_3_1",
|
||||
"tools_llama_3_2",
|
||||
"rag_llama_3_2",
|
||||
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
|
||||
|
||||
# Extract hostname if a full URL is provided
|
||||
if host.startswith(("http://", "https://")):
|
||||
# Keep the scheme (http:// or https://) but remove any trailing slashes
|
||||
host = host.rstrip("/")
|
||||
|
||||
fn = {
|
||||
"tools_llama_3_1": run_llama_3_1,
|
||||
"tools_llama_3_2": run_llama_3_2,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue