Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-05 18:22:41 +00:00
skip code interp
This commit is contained in:
parent 37b6da37ba
commit 876693e710
1 changed file with 57 additions and 34 deletions
@@ -8,15 +8,13 @@ from typing import Any, Dict
 from uuid import uuid4

 import pytest
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig

 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
-)
-from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig


 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -36,7 +34,9 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
         return -1


-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(
+    liquid_name: str, celcius: bool = True
+) -> Dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit

@@ -56,7 +56,10 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D

 @pytest.fixture(scope="session")
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
-    available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
+    available_shields = [
+        shield.identifier
+        for shield in llama_stack_client_with_mocked_inference.shields.list()
+    ]
     available_shields = available_shields[:1]
     agent_config = dict(
         model=text_model_id,
@@ -109,7 +112,9 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
         session_id=session_id,
     )

-    logs = [str(log) for log in AgentEventLogger().log(bomb_response) if log is not None]
+    logs = [
+        str(log) for log in AgentEventLogger().log(bomb_response) if log is not None
+    ]
     logs_str = "".join(logs)
     assert "I can't" in logs_str

@@ -170,7 +175,9 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
         Server__AgentConfig(**agent_config)


-def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_web_search(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant that can use web search to answer questions.",
@@ -201,7 +208,9 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
     assert found_tool_execution


-def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
+def test_builtin_tool_code_execution(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -230,7 +239,9 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
 # This test must be run in an environment where `bwrap` is available. If you are running against a
 # server, this means the _server_ must have `bwrap` available. If you are using library client, then
 # you must have `bwrap` available in test's environment.
-def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
+def test_code_interpreter_for_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     agent_config = {
         **agent_config,
         "tools": [
@@ -292,7 +303,9 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     assert "get_boiling_point" in logs_str


-def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
+def test_custom_tool_infinite_loop(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
@@ -315,7 +328,9 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
         stream=False,
     )

-    num_tool_calls = sum([1 if step.step_type == "tool_execution" else 0 for step in response.steps])
+    num_tool_calls = sum(
+        [1 if step.step_type == "tool_execution" else 0 for step in response.steps]
+    )
     assert num_tool_calls <= 5


@@ -327,18 +342,25 @@ def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_co


 def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
-    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "none"
+    )
     assert len(tool_execution_steps) == 0


-def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+def test_tool_choice_get_boiling_point(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     if "llama" not in agent_config["model"].lower():
         pytest.xfail("NotImplemented for non-llama models")

     tool_execution_steps = run_agent_with_tool_choice(
         llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
     )
-    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    assert (
+        len(tool_execution_steps) >= 1
+        and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+    )


 def run_agent_with_tool_choice(client, agent_config, tool_choice):
@@ -368,8 +390,12 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
     return [step for step in response.steps if step.step_type == "tool_execution"]


-@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
+@pytest.mark.parametrize(
+    "rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]
+)
+def test_rag_agent(
+    llama_stack_client_with_mocked_inference, agent_config, rag_tool_name
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -418,17 +444,22 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
             stream=False,
         )
         # rag is called
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        tool_execution_step = next(
+            step for step in response.steps if step.step_type == "tool_execution"
+        )
         assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
         assert all(
-            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+            doc_id.startswith("num-")
+            for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
         )
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()


-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
+def test_rag_agent_with_attachments(
+    llama_stack_client_with_mocked_inference, agent_config
+):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -520,19 +551,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
         ],
     }
     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
-    inflation_doc = Document(
-        document_id="test_csv",
-        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
-        mime_type="text/csv",
-        metadata={},
-    )
     user_prompts = [
-        (
-            "Here is a csv file, can you describe it?",
-            [inflation_doc],
-            "code_interpreter",
-            "",
-        ),
         (
             "when was Perplexity the company founded?",
             [],
@@ -555,7 +574,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
             documents=docs,
             stream=False,
         )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        tool_execution_step = next(
+            step for step in response.steps if step.step_type == "tool_execution"
+        )
         assert tool_execution_step.tool_calls[0].tool_name == tool_name
         if expected_kw:
             assert expected_kw in response.output_message.content.lower()
@@ -565,7 +586,9 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
 )
-def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config, client_tools):
+def test_create_turn_response(
+    llama_stack_client_with_mocked_inference, agent_config, client_tools
+):
     client_tool, expects_metadata = client_tools
     agent_config = {
         **agent_config,