mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
fix failing code interpreter tests
This commit is contained in:
parent
82395ba654
commit
db2ec110a1
2 changed files with 19 additions and 32 deletions
|
@ -78,7 +78,6 @@ def make_random_string(length: int = 8):
|
|||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
||||
MEMORY_QUERY_TOOL = "query_memory"
|
||||
CODE_INTERPRETER_TOOL = "code_interpreter"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
|
||||
|
||||
|
@ -787,7 +786,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
|
@ -275,14 +275,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
document_id=f"num-{i}",
|
||||
|
@ -292,15 +285,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
memory_bank_id = "test-memory-bank"
|
||||
agent_config["toolgroups"].append(
|
||||
dict(
|
||||
name="builtin::memory",
|
||||
args={"memory_bank_ids": [memory_bank_id]},
|
||||
)
|
||||
)
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
llama_stack_client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_id,
|
||||
params={
|
||||
|
@ -314,25 +299,28 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
bank_id=memory_bank_id,
|
||||
documents=documents,
|
||||
)
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"toolgroups": [
|
||||
dict(
|
||||
name="builtin::memory",
|
||||
args={
|
||||
"memory_bank_ids": [memory_bank_id],
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
rag_agent = Agent(llama_stack_client, agent_config)
|
||||
session_id = rag_agent.create_session("test-session")
|
||||
user_prompts = [
|
||||
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
|
||||
"Was anything related to 'Llama3' discussed, if so what?",
|
||||
"Tell me how to use LoRA",
|
||||
"What are the top 5 topics that were explained? Only list succinct bullet points.",
|
||||
]
|
||||
|
||||
for prompt in user_prompts:
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
print(f"User> {prompt}")
|
||||
response = rag_agent.create_turn(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert "Tool:query_memory" in logs_str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue