mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Fireworks agent client
This commit is contained in:
parent
8a3b64d1be
commit
21ede64d4b
1 changed files with 30 additions and 8 deletions
|
@ -20,8 +20,9 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
|
|
||||||
from .agents import * # noqa: F403
|
from .agents import * # noqa: F403
|
||||||
from .event_logger import EventLogger
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -35,13 +36,21 @@ def encodable_dict(d: BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AgentsClient(Agents):
|
class AgentsClient(Agents):
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str, port: int):
|
||||||
self.base_url = base_url
|
# Check if base_url is a full URL
|
||||||
|
parsed_url = urlparse(base_url)
|
||||||
|
if parsed_url.scheme:
|
||||||
|
# If it's a full URL, use it as is
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
else:
|
||||||
|
# If it's just a hostname, construct the URL with port
|
||||||
|
self.base_url = f"http://{base_url}:{port}"
|
||||||
|
|
||||||
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
|
async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
url = urljoin(self.base_url, "/agents/create")
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/agents/create",
|
url,
|
||||||
json={
|
json={
|
||||||
"agent_config": encodable_dict(agent_config),
|
"agent_config": encodable_dict(agent_config),
|
||||||
},
|
},
|
||||||
|
@ -56,8 +65,9 @@ class AgentsClient(Agents):
|
||||||
session_name: str,
|
session_name: str,
|
||||||
) -> AgentSessionCreateResponse:
|
) -> AgentSessionCreateResponse:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
url = urljoin(self.base_url, "/agents/session/create")
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/agents/session/create",
|
url,
|
||||||
json={
|
json={
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
"session_name": session_name,
|
"session_name": session_name,
|
||||||
|
@ -144,7 +154,7 @@ async def _run_agent(
|
||||||
|
|
||||||
|
|
||||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||||
api = AgentsClient(f"http://{host}:{port}")
|
api = AgentsClient(host, port)
|
||||||
|
|
||||||
tool_definitions = [
|
tool_definitions = [
|
||||||
SearchToolDefinition(
|
SearchToolDefinition(
|
||||||
|
@ -184,7 +194,7 @@ async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct
|
||||||
|
|
||||||
|
|
||||||
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||||
api = AgentsClient(f"http://{host}:{port}")
|
api = AgentsClient(host, port)
|
||||||
|
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
|
@ -225,7 +235,7 @@ async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Inst
|
||||||
|
|
||||||
|
|
||||||
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||||
api = AgentsClient(f"http://{host}:{port}")
|
api = AgentsClient(host, port)
|
||||||
|
|
||||||
# zero shot tools for llama3.2 text models
|
# zero shot tools for llama3.2 text models
|
||||||
tool_definitions = [
|
tool_definitions = [
|
||||||
|
@ -271,12 +281,24 @@ async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Main function that handles both hostname and full URL cases while keeping port required
|
||||||
|
@param host: Hostname or full URL
|
||||||
|
@param port: Port number (required)
|
||||||
|
@param run_type: Type of run (tools_llama_3_1, tools_llama_3_2, rag_llama_3_2)
|
||||||
|
@param model: Optional model name
|
||||||
|
"""
|
||||||
assert run_type in [
|
assert run_type in [
|
||||||
"tools_llama_3_1",
|
"tools_llama_3_1",
|
||||||
"tools_llama_3_2",
|
"tools_llama_3_2",
|
||||||
"rag_llama_3_2",
|
"rag_llama_3_2",
|
||||||
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
|
], f"Invalid run type {run_type}, must be one of tools_llama_3_1, tools_llama_3_2, rag_llama_3_2"
|
||||||
|
|
||||||
|
# Extract hostname if a full URL is provided
|
||||||
|
if host.startswith(("http://", "https://")):
|
||||||
|
# Keep the scheme (http:// or https://) but remove any trailing slashes
|
||||||
|
host = host.rstrip("/")
|
||||||
|
|
||||||
fn = {
|
fn = {
|
||||||
"tools_llama_3_1": run_llama_3_1,
|
"tools_llama_3_1": run_llama_3_1,
|
||||||
"tools_llama_3_2": run_llama_3_2,
|
"tools_llama_3_2": run_llama_3_2,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue