mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
fix failing code interpreter tests
This commit is contained in:
parent
82395ba654
commit
db2ec110a1
2 changed files with 19 additions and 32 deletions
|
@ -78,7 +78,6 @@ def make_random_string(length: int = 8):
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
||||||
MEMORY_QUERY_TOOL = "query_memory"
|
MEMORY_QUERY_TOOL = "query_memory"
|
||||||
CODE_INTERPRETER_TOOL = "code_interpreter"
|
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
|
|
||||||
|
|
||||||
|
@ -787,7 +786,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_defs: Dict[str, ToolDefinition],
|
tool_defs: Dict[str, ToolDefinition],
|
||||||
) -> None:
|
) -> None:
|
||||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||||
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
|
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||||
content_items = []
|
content_items = []
|
||||||
url_items = []
|
url_items = []
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
|
|
|
@ -275,14 +275,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
|
|
||||||
|
|
||||||
def test_rag_agent(llama_stack_client, agent_config):
|
def test_rag_agent(llama_stack_client, agent_config):
|
||||||
urls = [
|
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
|
||||||
"memory_optimizations.rst",
|
|
||||||
"chat.rst",
|
|
||||||
"llama3.rst",
|
|
||||||
"datasets.rst",
|
|
||||||
"qat_finetune.rst",
|
|
||||||
"lora_finetune.rst",
|
|
||||||
]
|
|
||||||
documents = [
|
documents = [
|
||||||
Document(
|
Document(
|
||||||
document_id=f"num-{i}",
|
document_id=f"num-{i}",
|
||||||
|
@ -292,15 +285,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
)
|
)
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
|
|
||||||
memory_bank_id = "test-memory-bank"
|
memory_bank_id = "test-memory-bank"
|
||||||
agent_config["toolgroups"].append(
|
|
||||||
dict(
|
|
||||||
name="builtin::memory",
|
|
||||||
args={"memory_bank_ids": [memory_bank_id]},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
|
||||||
llama_stack_client.memory_banks.register(
|
llama_stack_client.memory_banks.register(
|
||||||
memory_bank_id=memory_bank_id,
|
memory_bank_id=memory_bank_id,
|
||||||
params={
|
params={
|
||||||
|
@ -314,25 +299,28 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
bank_id=memory_bank_id,
|
bank_id=memory_bank_id,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
)
|
)
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
agent_config = {
|
||||||
|
**agent_config,
|
||||||
|
"toolgroups": [
|
||||||
|
dict(
|
||||||
|
name="builtin::memory",
|
||||||
|
args={
|
||||||
|
"memory_bank_ids": [memory_bank_id],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
rag_agent = Agent(llama_stack_client, agent_config)
|
||||||
|
session_id = rag_agent.create_session("test-session")
|
||||||
user_prompts = [
|
user_prompts = [
|
||||||
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
|
"What are the top 5 topics that were explained? Only list succinct bullet points.",
|
||||||
"Was anything related to 'Llama3' discussed, if so what?",
|
|
||||||
"Tell me how to use LoRA",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for prompt in user_prompts:
|
for prompt in user_prompts:
|
||||||
response = agent.create_turn(
|
print(f"User> {prompt}")
|
||||||
messages=[
|
response = rag_agent.create_turn(
|
||||||
{
|
messages=[{"role": "user", "content": prompt}],
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "Tool:query_memory" in logs_str
|
assert "Tool:query_memory" in logs_str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue