feat(api): simplify client imports (#1687)

# What does this PR do?
closes #1554 

## Test Plan
test_agents.py
This commit is contained in:
ehhuang 2025-03-20 10:15:49 -07:00 committed by GitHub
parent 515c16e352
commit ea6a4a14ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 40 additions and 58 deletions

View file

@ -8,9 +8,7 @@ from typing import Any, Dict
from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agents.turn_create_params import Document
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack.apis.agents.agents import (
@ -92,7 +90,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(simple_hello) if log is not None]
logs_str = "".join(logs)
assert "hello" in logs_str.lower()
@ -111,7 +109,7 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
logs_str = "".join(logs)
assert "I can't" in logs_str
@ -192,7 +190,7 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "tool_execution>" in logs_str
@ -221,7 +219,7 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "541" in logs_str
@ -262,7 +260,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen
session_id=session_id,
documents=input.get("documents", None),
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:code_interpreter" in logs_str
@ -287,7 +285,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "-100" in logs_str
assert "get_boiling_point" in logs_str