fix failing code interpreter tests

This commit is contained in:
Dinesh Yeduguru 2025-01-07 22:13:33 -08:00
parent 82395ba654
commit db2ec110a1
2 changed files with 19 additions and 32 deletions

View file

@ -78,7 +78,6 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_TOOL_GROUP_ID = "builtin::memory"
MEMORY_QUERY_TOOL = "query_memory"
CODE_INTERPRETER_TOOL = "code_interpreter"
WEB_SEARCH_TOOL = "web_search"
@ -787,7 +786,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")

View file

@ -275,14 +275,7 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_rag_agent(llama_stack_client, agent_config):
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
documents = [
Document(
document_id=f"num-{i}",
@ -292,15 +285,7 @@ def test_rag_agent(llama_stack_client, agent_config):
)
for i, url in enumerate(urls)
]
memory_bank_id = "test-memory-bank"
agent_config["toolgroups"].append(
dict(
name="builtin::memory",
args={"memory_bank_ids": [memory_bank_id]},
)
)
agent = Agent(llama_stack_client, agent_config)
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
@ -314,25 +299,28 @@ def test_rag_agent(llama_stack_client, agent_config):
bank_id=memory_bank_id,
documents=documents,
)
session_id = agent.create_session(f"test-session-{uuid4()}")
agent_config = {
**agent_config,
"toolgroups": [
dict(
name="builtin::memory",
args={
"memory_bank_ids": [memory_bank_id],
},
)
],
}
rag_agent = Agent(llama_stack_client, agent_config)
session_id = rag_agent.create_session("test-session")
user_prompts = [
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
"Was anything related to 'Llama3' discussed, if so what?",
"Tell me how to use LoRA",
"What are the top 5 topics that were explained? Only list succinct bullet points.",
]
for prompt in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
print(f"User> {prompt}")
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_memory" in logs_str