mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat(api): simplify client imports (#1687)
# What does this PR do?
closes #1554

## Test Plan
test_agents.py
This commit is contained in:
parent
515c16e352
commit
ea6a4a14ce
11 changed files with 40 additions and 58 deletions
|
@ -1203,7 +1203,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import InferenceEventLogger\n",
|
||||
"\n",
|
||||
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
|
||||
"print(f'User> {message[\"content\"]}', \"green\")\n",
|
||||
|
@ -1215,7 +1215,7 @@
|
|||
")\n",
|
||||
"\n",
|
||||
"# Print the tokens while they are received\n",
|
||||
"for log in EventLogger().log(response):\n",
|
||||
"for log in InferenceEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -1632,8 +1632,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
|
@ -1659,7 +1658,7 @@
|
|||
" ],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -1808,14 +1807,12 @@
|
|||
],
|
||||
"source": [
|
||||
"import uuid\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
|
||||
"from termcolor import cprint\n",
|
||||
"from llama_stack_client.types import Document\n",
|
||||
"\n",
|
||||
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
|
||||
"documents = [\n",
|
||||
" Document(\n",
|
||||
" RAGDocument(\n",
|
||||
" document_id=f\"num-{i}\",\n",
|
||||
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
|
||||
" mime_type=\"text/plain\",\n",
|
||||
|
@ -1858,7 +1855,7 @@
|
|||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()"
|
||||
]
|
||||
},
|
||||
|
@ -1969,7 +1966,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.types.agents.turn_create_params import Document\n",
|
||||
"from llama_stack_client import Document\n",
|
||||
"\n",
|
||||
"codex_agent = Agent(\n",
|
||||
" client, \n",
|
||||
|
@ -2891,8 +2888,7 @@
|
|||
],
|
||||
"source": [
|
||||
"# NBVAL_SKIP\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"from termcolor import cprint\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
|
@ -2918,7 +2914,7 @@
|
|||
" ],\n",
|
||||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
@ -2993,8 +2989,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
|
||||
"from llama_stack_client import Agent, AgentEventLogger\n",
|
||||
"\n",
|
||||
"agent = Agent(\n",
|
||||
" client, \n",
|
||||
|
@ -3021,7 +3016,7 @@
|
|||
" session_id=session_id,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" for log in EventLogger().log(response):\n",
|
||||
" for log in AgentEventLogger().log(response):\n",
|
||||
" log.print()\n"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -47,9 +47,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client import LlamaStackClient, Agent\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"import json\n",
|
||||
"import uuid\n",
|
||||
|
|
|
@ -34,10 +34,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from llama_stack_client import LlamaStackClient\n",
|
||||
"from llama_stack_client import LlamaStackClient, Agent\n",
|
||||
"from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n",
|
||||
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
|
||||
"from llama_stack_client.lib.agents.agent import Agent\n",
|
||||
"from rich.pretty import pprint\n",
|
||||
"import json\n",
|
||||
"import uuid\n",
|
||||
|
|
|
@ -14,7 +14,7 @@ Agents are configured using the `AgentConfig` class, which includes:
|
|||
- **Safety Shields**: Guardrails to ensure responsible AI behavior
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
|
||||
# Create the agent
|
||||
|
@ -44,14 +44,14 @@ Each interaction with an agent is called a "turn" and consists of:
|
|||
- **Output Message**: The agent's response
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import AgentEventLogger
|
||||
|
||||
# Create a turn with streaming response
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": "Tell me about Llama models"}],
|
||||
)
|
||||
for log in EventLogger().log(turn_response):
|
||||
for log in AgentEventLogger().log(turn_response):
|
||||
log.print()
|
||||
```
|
||||
### Non-Streaming
|
||||
|
|
|
@ -67,9 +67,7 @@ sequenceDiagram
|
|||
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
||||
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
|
||||
from rich.pretty import pprint
|
||||
|
||||
# Replace host and port
|
||||
|
@ -113,7 +111,7 @@ response = agent.create_turn(
|
|||
)
|
||||
|
||||
# Monitor each step of execution
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
# Using non-streaming API, the response contains input, steps, and output.
|
||||
|
|
|
@ -23,9 +23,7 @@ In this example, we will show you how to:
|
|||
|
||||
##### Building a Search Agent
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
|
||||
|
||||
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||
|
||||
|
@ -54,7 +52,7 @@ for prompt in user_prompts:
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -55,11 +55,11 @@ chunks_response = client.vector_io.query(
|
|||
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
|
||||
|
||||
```python
|
||||
from llama_stack_client.types import Document
|
||||
from llama_stack_client import RAGDocument
|
||||
|
||||
urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
|
@ -86,7 +86,7 @@ results = client.tool_runtime.rag_tool.query(
|
|||
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
# Create agent with memory
|
||||
agent = Agent(
|
||||
|
@ -140,9 +140,9 @@ response = agent.create_turn(
|
|||
|
||||
You can print the response with the code below.
|
||||
```python
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client import AgentEventLogger
|
||||
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -189,7 +189,7 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
|
|||
## Simple Example: Using an Agent with the Code-Interpreter Tool
|
||||
|
||||
```python
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client import Agent
|
||||
|
||||
# Instantiate the AI agent with the given configuration
|
||||
agent = Agent(
|
||||
|
|
|
@ -197,9 +197,7 @@ import os
|
|||
import uuid
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
|
||||
def create_http_client():
|
||||
|
@ -225,7 +223,7 @@ client = (
|
|||
# Documents to be used for RAG
|
||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
|
@ -284,7 +282,7 @@ for prompt in user_prompts:
|
|||
messages=[{"role": "user", "content": prompt}],
|
||||
session_id=session_id,
|
||||
)
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
```
|
||||
|
||||
|
|
|
@ -5,9 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.shared.document import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
@ -35,7 +33,7 @@ def rag_chat_page():
|
|||
)
|
||||
if st.button("Create Vector Database"):
|
||||
documents = [
|
||||
Document(
|
||||
RAGDocument(
|
||||
document_id=uploaded_file.name,
|
||||
content=data_url_from_file(uploaded_file),
|
||||
)
|
||||
|
@ -167,7 +165,7 @@ def rag_chat_page():
|
|||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
for log in EventLogger().log(response):
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
|
|
|
@ -8,9 +8,7 @@ from typing import Any, Dict
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agents.turn_create_params import Document
|
||||
from llama_stack_client import Agent, AgentEventLogger, Document
|
||||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||
|
||||
from llama_stack.apis.agents.agents import (
|
||||
|
@ -92,7 +90,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(simple_hello) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "hello" in logs_str.lower()
|
||||
|
@ -111,7 +109,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "I can't" in logs_str
|
||||
|
||||
|
@ -192,7 +190,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "tool_execution>" in logs_str
|
||||
|
@ -221,7 +219,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
|
|||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
|
||||
assert "541" in logs_str
|
||||
|
@ -262,7 +260,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
|
|||
session_id=session_id,
|
||||
documents=input.get("documents", None),
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:code_interpreter" in logs_str
|
||||
|
||||
|
@ -287,7 +285,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "-100" in logs_str
|
||||
assert "get_boiling_point" in logs_str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue