skip code interp

Xi Yan 2025-03-28 12:48:56 -07:00
parent 37b6da37ba
commit 876693e710


@@ -8,15 +8,13 @@ from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
-)
-from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -36,7 +34,9 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
         return -1
 
 
-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(
+    liquid_name: str, celcius: bool = True
+) -> Dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit
 
@@ -56,7 +56,10 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
 
 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
-    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
+    available_shields = [
+        shield.identifier
+        for shield in llama_stack_client_with_mocked_inference.shields.list()
+    ]
     available_shields = available_shields[:1]
     agent_config = dict(
         model=text_model_id,
@@ -109,7 +112,9 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
         session_id=session_id,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
+    logs = [
+        str(log) for log in AgentEventLogger().log(bomb_response) if log is not None
+    ]
     logs_str = "".join(logs)
 
     assert "I can't" in logs_str
@@ -170,7 +175,9 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
         Server__AgentConfig(**agent_config)
 
 
-def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_web_search(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant that can use web search to answer questions.",
@@ -201,7 +208,9 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     assert found_tool_execution
 
 
-def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_code_execution(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -230,7 +239,9 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
+def test_code_interpreter_for_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -292,7 +303,9 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     assert "get_boiling_point" in logs_str
 
 
-def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
+def test_custom_tool_infinite_loop(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -315,7 +328,9 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
         stream=False,
     )
 
-    num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
+    num_tool_calls = sum(
+        [1 if step.step_type == "tool_execution" else 0 for step in response.steps]
+    )
     assert num_tool_calls <= 5
 
 
@@ -327,18 +342,25 @@ def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config):
 
 
 def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
-    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "none"
+    )
     assert len(tool_execution_steps) == 0
 
 
-def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+def test_tool_choice_get_boiling_point(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     if "llama" not in agent_config["model"].lower():
         pytest.xfail("NotImplemented for non-llama models")
 
     tool_execution_steps = run_agent_with_tool_choice(
         llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
     )
-    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    assert (
+        len(tool_execution_steps) >= 1
+        and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    )
 
 
 def run_agent_with_tool_choice(client, agent_config, tool_choice):
@@ -368,8 +390,12 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
     return [step for step in response.steps if step.step_type == "tool_execution"]
 
 
-@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
+@pytest.mark.parametrize(
+    "rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]
+)
+def test_rag_agent(
+    llama_stack_client_with_mocked_inference, agent_config, rag_tool_name
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -418,17 +444,22 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
             stream=False,
         )
         # rag is called
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        tool_execution_step = next(
+            step for step in response.steps if step.step_type == "tool_execution"
+        )
         assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
         assert all(
-            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+            doc_id.startswith("num-")
+            for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
         )
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
+def test_rag_agent_with_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -520,19 +551,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
         ],
     }
     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
-    inflation_doc = Document(
-        document_id="test_csv",
-        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
-        mime_type="text/csv",
-        metadata={},
-    )
     user_prompts = [
-        (
-            "Here is a csv file, can you describe it?",
-            [inflation_doc],
-            "code_interpreter",
-            "",
-        ),
         (
             "when was Perplexity the company founded?",
             [],
@@ -555,7 +574,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
             documents=docs,
             stream=False,
         )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        tool_execution_step = next(
+            step for step in response.steps if step.step_type == "tool_execution"
+        )
         assert tool_execution_step.tool_calls[0].tool_name == tool_name
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
@@ -565,7 +586,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
-def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
+def test_create_turn_response(
+    llama_stack_client_with_mocked_inference, agent_config, client_tools
+):
     client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,