feat(api): simplify client imports (#1687)

# What does this PR do?
closes #1554 

## Test Plan
Verified by running `test_agents.py`.
This commit is contained in:
ehhuang 2025-03-20 10:15:49 -07:00 committed by GitHub
parent 515c16e352
commit ea6a4a14ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 40 additions and 58 deletions

View file

@@ -1203,7 +1203,7 @@
}
],
"source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"from llama_stack_client import InferenceEventLogger\n",
"\n",
"message = {\"role\": \"user\", \"content\": \"Write me a sonnet about llama\"}\n",
"print(f'User> {message[\"content\"]}', \"green\")\n",
@@ -1215,7 +1215,7 @@
")\n",
"\n",
"# Print the tokens while they are received\n",
"for log in EventLogger().log(response):\n",
"for log in InferenceEventLogger().log(response):\n",
" log.print()\n"
]
},
@@ -1632,8 +1632,7 @@
}
],
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"from termcolor import cprint\n",
"\n",
"agent = Agent(\n",
@@ -1659,7 +1658,7 @@
" ],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@@ -1808,14 +1807,12 @@
],
"source": [
"import uuid\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument\n",
"from termcolor import cprint\n",
"from llama_stack_client.types import Document\n",
"\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"memory_optimizations.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n",
" Document(\n",
" RAGDocument(\n",
" document_id=f\"num-{i}\",\n",
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" mime_type=\"text/plain\",\n",
@@ -1858,7 +1855,7 @@
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()"
]
},
@@ -1969,7 +1966,7 @@
}
],
"source": [
"from llama_stack_client.types.agents.turn_create_params import Document\n",
"from llama_stack_client import Document\n",
"\n",
"codex_agent = Agent(\n",
" client, \n",
@@ -2891,8 +2888,7 @@
],
"source": [
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"from termcolor import cprint\n",
"\n",
"agent = Agent(\n",
@@ -2918,7 +2914,7 @@
" ],\n",
" session_id=session_id,\n",
" )\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},
@@ -2993,8 +2989,7 @@
}
],
"source": [
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client import Agent, AgentEventLogger\n",
"\n",
"agent = Agent(\n",
" client, \n",
@@ -3021,7 +3016,7 @@
" session_id=session_id,\n",
" )\n",
"\n",
" for log in EventLogger().log(response):\n",
" for log in AgentEventLogger().log(response):\n",
" log.print()\n"
]
},