mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
fix agent server tests
This commit is contained in:
parent
439f52b067
commit
18d9937500
4 changed files with 41 additions and 63 deletions
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import parse_obj_as
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
@ -39,7 +39,6 @@ from llama_stack.distribution.datatypes import (
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
RoutedProtocol,
|
RoutedProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
|
@ -361,7 +360,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||||
"embedding_dimension"
|
"embedding_dimension"
|
||||||
]
|
]
|
||||||
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
@ -525,7 +524,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
# parse tool group to the type if dict
|
# parse tool group to the type if dict
|
||||||
tool_group = parse_obj_as(ToolGroupDef, tool_group)
|
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
|
||||||
if isinstance(tool_group, MCPToolGroupDef):
|
if isinstance(tool_group, MCPToolGroupDef):
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||||
tool_group
|
tool_group
|
||||||
|
|
|
@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"safety": "llama_guard",
|
"safety": "llama_guard",
|
||||||
"memory": "faiss",
|
"memory": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
id="meta_reference",
|
id="meta_reference",
|
||||||
marks=pytest.mark.meta_reference,
|
marks=pytest.mark.meta_reference,
|
||||||
|
@ -30,7 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"safety": "llama_guard",
|
"safety": "llama_guard",
|
||||||
"memory": "faiss",
|
"memory": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
id="ollama",
|
id="ollama",
|
||||||
marks=pytest.mark.ollama,
|
marks=pytest.mark.ollama,
|
||||||
|
@ -42,7 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
# make this work with Weaviate which is what the together distro supports
|
# make this work with Weaviate which is what the together distro supports
|
||||||
"memory": "faiss",
|
"memory": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
id="together",
|
id="together",
|
||||||
marks=pytest.mark.together,
|
marks=pytest.mark.together,
|
||||||
|
@ -53,7 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"safety": "llama_guard",
|
"safety": "llama_guard",
|
||||||
"memory": "faiss",
|
"memory": "faiss",
|
||||||
"agents": "meta_reference",
|
"agents": "meta_reference",
|
||||||
"tool_runtime": "memory",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
id="fireworks",
|
id="fireworks",
|
||||||
marks=pytest.mark.fireworks,
|
marks=pytest.mark.fireworks,
|
||||||
|
@ -64,7 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
"safety": "remote",
|
"safety": "remote",
|
||||||
"memory": "remote",
|
"memory": "remote",
|
||||||
"agents": "remote",
|
"agents": "remote",
|
||||||
"tool_runtime": "memory",
|
"tool_runtime": "memory_and_search",
|
||||||
},
|
},
|
||||||
id="remote",
|
id="remote",
|
||||||
marks=pytest.mark.remote,
|
marks=pytest.mark.remote,
|
||||||
|
|
|
@ -64,7 +64,7 @@ def agents_meta_reference() -> ProviderFixture:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def tool_runtime_memory() -> ProviderFixture:
|
def tool_runtime_memory_and_search() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
|
@ -72,31 +72,19 @@ def tool_runtime_memory() -> ProviderFixture:
|
||||||
provider_type="inline::memory-runtime",
|
provider_type="inline::memory-runtime",
|
||||||
config={},
|
config={},
|
||||||
),
|
),
|
||||||
Provider(
|
|
||||||
provider_id="brave-search",
|
|
||||||
provider_type="inline::brave-search",
|
|
||||||
config={
|
|
||||||
"api_key": os.environ["BRAVE_SEARCH_API_KEY"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="tavily-search",
|
provider_id="tavily-search",
|
||||||
provider_type="inline::tavily-search",
|
provider_type="remote::tavily-search",
|
||||||
config={
|
config={
|
||||||
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Provider(
|
|
||||||
provider_id="code-interpreter",
|
|
||||||
provider_type="inline::code-interpreter",
|
|
||||||
config={},
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||||
TOOL_RUNTIME_FIXTURES = ["memory"]
|
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
@ -173,43 +161,25 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
name="memory",
|
name="memory",
|
||||||
description="memory",
|
description="memory",
|
||||||
parameters=[
|
parameters=[
|
||||||
ToolParameter(
|
|
||||||
name="session_id",
|
|
||||||
description="session id",
|
|
||||||
parameter_type="string",
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name="input_messages",
|
name="input_messages",
|
||||||
description="messages",
|
description="messages",
|
||||||
parameter_type="list",
|
parameter_type="list",
|
||||||
required=True,
|
required=True,
|
||||||
),
|
),
|
||||||
ToolParameter(
|
|
||||||
name="attachments",
|
|
||||||
description="attachments",
|
|
||||||
parameter_type="list",
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
metadata={},
|
metadata={
|
||||||
|
"config": {
|
||||||
|
"memory_bank_configs": [
|
||||||
|
{"bank_id": "test_bank", "type": "vector"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
provider_id="memory-runtime",
|
provider_id="memory-runtime",
|
||||||
),
|
),
|
||||||
ToolGroupInput(
|
|
||||||
tool_group_id="code_interpreter_group",
|
|
||||||
tool_group=UserDefinedToolGroupDef(
|
|
||||||
tools=[
|
|
||||||
BuiltInToolDef(
|
|
||||||
built_in_type=BuiltinTool.code_interpreter,
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
provider_id="code-interpreter",
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
test_stack = await construct_stack_for_test(
|
test_stack = await construct_stack_for_test(
|
||||||
|
|
|
@ -8,19 +8,13 @@ import os
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgentTool,
|
|
||||||
AgentTurnResponseEventType,
|
AgentTurnResponseEventType,
|
||||||
AgentTurnResponseStepCompletePayload,
|
AgentTurnResponseStepCompletePayload,
|
||||||
AgentTurnResponseStreamChunk,
|
AgentTurnResponseStreamChunk,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
Attachment,
|
|
||||||
MemoryToolDefinition,
|
|
||||||
SearchEngineType,
|
|
||||||
SearchToolDefinition,
|
|
||||||
ShieldCallStep,
|
ShieldCallStep,
|
||||||
StepType,
|
StepType,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
|
@ -228,7 +222,7 @@ class TestAgents:
|
||||||
check_turn_complete_event(turn_response, session_id, sample_messages)
|
check_turn_complete_event(turn_response, session_id, sample_messages)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rag_agent_as_attachments(
|
async def test_rag_agent(
|
||||||
self,
|
self,
|
||||||
agents_stack,
|
agents_stack,
|
||||||
attachment_message,
|
attachment_message,
|
||||||
|
@ -236,6 +230,8 @@ class TestAgents:
|
||||||
common_params,
|
common_params,
|
||||||
):
|
):
|
||||||
agents_impl = agents_stack.impls[Api.agents]
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
memory_banks_impl = agents_stack.impls[Api.memory_banks]
|
||||||
|
memory_impl = agents_stack.impls[Api.memory]
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
"chat.rst",
|
"chat.rst",
|
||||||
|
@ -244,14 +240,28 @@ class TestAgents:
|
||||||
"qat_finetune.rst",
|
"qat_finetune.rst",
|
||||||
"lora_finetune.rst",
|
"lora_finetune.rst",
|
||||||
]
|
]
|
||||||
|
documents = [
|
||||||
attachments = [
|
MemoryBankDocument(
|
||||||
Attachment(
|
document_id=f"num-{i}",
|
||||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
|
metadata={},
|
||||||
)
|
)
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
|
await memory_banks_impl.register_memory_bank(
|
||||||
|
memory_bank_id="test_bank",
|
||||||
|
params=VectorMemoryBankParams(
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
overlap_size_in_tokens=64,
|
||||||
|
),
|
||||||
|
provider_id="faiss",
|
||||||
|
)
|
||||||
|
memory_impl.insert_documents(
|
||||||
|
bank_id="test_bank",
|
||||||
|
documents=documents,
|
||||||
|
)
|
||||||
|
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
**{
|
**{
|
||||||
|
@ -266,7 +276,6 @@ class TestAgents:
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=attachment_message,
|
messages=attachment_message,
|
||||||
attachments=attachments,
|
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
turn_response = [
|
turn_response = [
|
||||||
|
@ -290,11 +299,11 @@ class TestAgents:
|
||||||
assert len(turn_response) > 0
|
assert len(turn_response) > 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_tavily_search(
|
||||||
self, agents_stack, search_query_messages, common_params
|
self, agents_stack, search_query_messages, common_params
|
||||||
):
|
):
|
||||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
await create_agent_turn_with_search_tool(
|
await create_agent_turn_with_search_tool(
|
||||||
agents_stack,
|
agents_stack,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue