From ea6a4a14cea6608853c547a5ea28a7c6d763e6bf Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 20 Mar 2025 10:15:49 -0700 Subject: [PATCH] feat(api): simplify client imports (#1687) # What does this PR do? closes #1554 ## Test Plan test_agents.py --- docs/getting_started.ipynb | 29 ++++++++----------- .../Llama_Stack_Agent_Workflows.ipynb | 3 +- .../notebooks/Llama_Stack_RAG_Lifecycle.ipynb | 4 +-- docs/source/building_applications/agent.md | 6 ++-- .../agent_execution_loop.md | 6 ++-- docs/source/building_applications/evals.md | 6 ++-- docs/source/building_applications/rag.md | 10 +++---- docs/source/building_applications/tools.md | 2 +- docs/source/getting_started/index.md | 8 ++--- .../distribution/ui/page/playground/rag.py | 8 ++--- tests/integration/agents/test_agents.py | 16 +++++----- 11 files changed, 40 insertions(+), 58 deletions(-) diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index fd625a394..e361be277 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -1203,7 +1203,7 @@ } ], "source": [ - "from llama_stack_client.lib.inference.event_logger import EventLogger\n", + "from llama_stack_client import InferenceEventLogger\n", "\n", "message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n", "print(f'User> {message[\"content\"]}', \"green\")\n", @@ -1215,7 +1215,7 @@ ")\n", "\n", "# Print the tokens while they are received\n", - "for log in EventLogger().log(response):\n", + "for log in InferenceEventLogger().log(response):\n", " log.print()\n" ] }, @@ -1632,8 +1632,7 @@ } ], "source": [ - "from llama_stack_client.lib.agents.agent import Agent\n", - "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client import Agent, AgentEventLogger\n", "from termcolor import cprint\n", "\n", "agent = Agent(\n", @@ -1659,7 +1658,7 @@ " ],\n", " session_id=session_id,\n", " )\n", - " for log in EventLogger().log(response):\n", + " for log in AgentEventLogger().log(response):\n", " log.print()\n" ] }, @@ -1808,14 +1807,12 @@ ], "source": [ "import uuid\n", - "from llama_stack_client.lib.agents.agent import Agent\n", - "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n", "from termcolor import cprint\n", - "from llama_stack_client.types import Document\n", "\n", "urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n", "documents = [\n", - " Document(\n", + " RAGDocument(\n", " document_id=f\"num-{i}\",\n", " content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n", " mime_type=\"text/plain\",\n", @@ -1858,7 +1855,7 @@ " messages=[{\"role\": \"user\", \"content\": prompt}],\n", " session_id=session_id,\n", " )\n", - " for log in EventLogger().log(response):\n", + " for log in AgentEventLogger().log(response):\n", " log.print()" ] }, @@ -1969,7 +1966,7 @@ } ], "source": [ - "from llama_stack_client.types.agents.turn_create_params import Document\n", + "from llama_stack_client import Document\n", "\n", "codex_agent = Agent(\n", " client, \n", @@ -2891,8 +2888,7 @@ ], "source": [ "# NBVAL_SKIP\n", - "from llama_stack_client.lib.agents.agent import Agent\n", - "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client import Agent, AgentEventLogger\n", "from termcolor import cprint\n", "\n", "agent = Agent(\n", @@ -2918,7 +2914,7 @@ " ],\n", " session_id=session_id,\n", " )\n", - " for log in EventLogger().log(response):\n", + " for log in AgentEventLogger().log(response):\n", " log.print()\n" ] }, @@ -2993,8 +2989,7 @@ } ], "source": [ - "from llama_stack_client.lib.agents.agent import Agent\n", - "from llama_stack_client.lib.agents.event_logger import EventLogger\n", + "from llama_stack_client import Agent, AgentEventLogger\n", "\n", "agent = Agent(\n", " client, \n", @@ -3021,7 +3016,7 @@ " session_id=session_id,\n", " )\n", "\n", - " for log in EventLogger().log(response):\n", + " for log in AgentEventLogger().log(response):\n", " log.print()\n" ] }, diff --git a/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb b/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb index f800fb1d4..cad28ab82 100644 --- a/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb +++ b/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb @@ -47,9 +47,8 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_stack_client import LlamaStackClient\n", + "from llama_stack_client import LlamaStackClient, Agent\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", - "from llama_stack_client.lib.agents.agent import Agent\n", "from rich.pretty import pprint\n", "import json\n", "import uuid\n", diff --git a/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb b/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb index 0d7b462cc..36d28dd16 100644 --- a/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb +++ b/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb @@ -34,10 +34,8 @@ } ], "source": [ - "from llama_stack_client import LlamaStackClient\n", + "from llama_stack_client import LlamaStackClient, Agent\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", - "from llama_stack_client.types.agent_create_params import AgentConfig\n", - "from llama_stack_client.lib.agents.agent import Agent\n", "from rich.pretty import pprint\n", "import json\n", "import uuid\n", diff --git a/docs/source/building_applications/agent.md b/docs/source/building_applications/agent.md index 3836ab701..283fb45e4 100644 --- a/docs/source/building_applications/agent.md +++ b/docs/source/building_applications/agent.md @@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes: - **Safety Shields**: Guardrails to ensure responsible AI behavior ```python -from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client import Agent # Create the agent @@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of: - **Output Message**: The agent's response ```python -from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client import AgentEventLogger # Create a turn with streaming response turn_response = agent.create_turn( session_id=session_id, messages=[{"role": "user", "content": "Tell me about Llama models"}], ) -for log in EventLogger().log(turn_response): +for log in AgentEventLogger().log(turn_response): log.print() ``` ### Non-Streaming diff --git a/docs/source/building_applications/agent_execution_loop.md b/docs/source/building_applications/agent_execution_loop.md index eebaccc66..a180602c6 100644 --- a/docs/source/building_applications/agent_execution_loop.md +++ b/docs/source/building_applications/agent_execution_loop.md @@ -67,9 +67,7 @@ sequenceDiagram Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: ```python -from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger from rich.pretty import pprint # Replace host and port @@ -113,7 +111,7 @@ response = agent.create_turn( ) # Monitor each step of execution -for log in EventLogger().log(response): +for log in AgentEventLogger().log(response): log.print() # Using non-streaming API, the response contains input, steps, and output. diff --git a/docs/source/building_applications/evals.md b/docs/source/building_applications/evals.md index 211d3bc26..ded62cebb 100644 --- a/docs/source/building_applications/evals.md +++ b/docs/source/building_applications/evals.md @@ -23,9 +23,7 @@ In this example, we will show you how to: ##### Building a Search Agent ```python -from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}") @@ -54,7 +52,7 @@ for prompt in user_prompts: session_id=session_id, ) - for log in EventLogger().log(response): + for log in AgentEventLogger().log(response): log.print() ``` diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e39ec0d5e..c3d02d7dc 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -55,11 +55,11 @@ chunks_response = client.vector_io.query( A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces. ```python -from llama_stack_client.types import Document +from llama_stack_client import RAGDocument urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"] documents = [ - Document( + RAGDocument( document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", @@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query( One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: ```python -from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client import Agent # Create agent with memory agent = Agent( @@ -140,9 +140,9 @@ response = agent.create_turn( You can print the response with below. ```python -from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client import AgentEventLogger -for log in EventLogger().log(response): +for log in AgentEventLogger().log(response): log.print() ``` diff --git a/docs/source/building_applications/tools.md b/docs/source/building_applications/tools.md index d5354a3da..94841a773 100644 --- a/docs/source/building_applications/tools.md +++ b/docs/source/building_applications/tools.md @@ -189,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools") ## Simple Example: Using an Agent with the Code-Interpreter Tool ```python -from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client import Agent # Instantiate the AI agent with the given configuration agent = Agent( diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index 7e4446393..f846c9ff0 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -197,9 +197,7 @@ import os import uuid from termcolor import cprint -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types import Document +from llama_stack_client import Agent, AgentEventLogger, RAGDocument def create_http_client(): @@ -225,7 +223,7 @@ client = ( # Documents to be used for RAG urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [ - Document( + RAGDocument( document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", @@ -284,7 +282,7 @@ for prompt in user_prompts: messages=[{"role": "user", "content": prompt}], session_id=session_id, ) - for log in EventLogger().log(response): + for log in AgentEventLogger().log(response): log.print() ``` diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index e2f451668..fded229c4 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -5,9 +5,7 @@ # the root directory of this source tree. import streamlit as st -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types.shared.document import Document +from llama_stack_client import Agent, AgentEventLogger, RAGDocument from llama_stack.distribution.ui.modules.api import llama_stack_api from llama_stack.distribution.ui.modules.utils import data_url_from_file @@ -35,7 +33,7 @@ def rag_chat_page(): ) if st.button("Create Vector Database"): documents = [ - Document( + RAGDocument( document_id=uploaded_file.name, content=data_url_from_file(uploaded_file), ) @@ -167,7 +165,7 @@ def rag_chat_page(): message_placeholder = st.empty() full_response = "" retrieval_response = "" - for log in EventLogger().log(response): + for log in AgentEventLogger().log(response): log.print() if log.role == "tool_execution": retrieval_response += log.content.replace("====", "").strip() diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 581cc9f45..7011dc02d 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -8,9 +8,7 @@ from typing import Any, Dict from uuid import uuid4 import pytest -from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types.agents.turn_create_params import Document +from llama_stack_client import Agent, AgentEventLogger, Document from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack.apis.agents.agents import ( @@ -92,7 +90,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config): session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(simple_hello) if log is not None] logs_str = "".join(logs) assert "hello" in logs_str.lower() @@ -111,7 +109,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config): session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None] logs_str = "".join(logs) assert "I can't" in logs_str @@ -192,7 +190,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "tool_execution>" in logs_str @@ -221,7 +219,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a ], session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "541" in logs_str @@ -262,7 +260,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen session_id=session_id, documents=input.get("documents", None), ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:code_interpreter" in logs_str @@ -287,7 +285,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config): session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "-100" in logs_str assert "get_boiling_point" in logs_str