chore!: remove VectorDB APIs (BREAKING CHANGE) (#3774)

# What does this PR do?
Removes the VectorDB APIs from the API surface and from our tests.

Moves the affected tests to the Vector Stores API.
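
For downstream users the change is mechanical: wherever a vector database was registered through the legacy `vector_dbs` API, a vector store is created instead. Below is a minimal before/after sketch, assuming a locally running stack; the `extra_body` keys for embedding settings are an assumption based on common llama-stack usage, not something this PR defines.

```python
from uuid import uuid4

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

# Before (removed by this PR): register a VectorDB under a caller-chosen id.
vector_db_id = f"test-vector-db-{uuid4()}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

# After: create a Vector Store and use its server-assigned id.
# Passing embedding settings via extra_body is an assumption; check the
# client version you are on.
vector_store = client.vector_stores.create(
    name="test-vector-store",
    extra_body={
        "embedding_model": "all-MiniLM-L6-v2",
        "embedding_dimension": 384,
    },
)
vector_store_id = vector_store.id
```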



---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Francisco Arceo authored on 2025-10-11 17:07:08 -04:00, committed by GitHub
parent 06e4cd8e02
commit a165b8b5bb
111 changed files with 60412 additions and 2765 deletions


@@ -8,7 +8,6 @@ from typing import Any
from uuid import uuid4

import pytest
import requests
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
@@ -443,118 +442,6 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
    return [step for step in response.steps if step.step_type == "tool_execution"]
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
# small chunks help to get specific info out of the docs
chunk_size_in_tokens=256,
)
agent_config = {
**agent_config,
"tools": [
dict(
name=rag_tool_name,
args={
"vector_db_ids": [vector_db_id],
},
)
],
}
rag_agent = Agent(llama_stack_client, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
]
for prompt, expected_kw in user_prompts:
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
stream=False,
)
# rag is called
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
# document ids are present in metadata
assert all(
doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
)
if expected_kw:
assert expected_kw in response.output_message.content.lower()


def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
    urls = ["llama3.rst", "lora_finetune.rst"]
    documents = [
        # passing as URL
        Document(
            document_id="num-0",
            content={
                "type": "url",
                "uri": f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[0]}",
            },
            mime_type="text/plain",
            metadata={},
        ),
        # passing as str
        Document(
            document_id="num-1",
            content=requests.get(
                f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[1]}"
            ).text[:500],
            mime_type="text/plain",
            metadata={},
        ),
    ]
    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
    user_prompts = [
        (
            "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
            documents,
        ),
        (
            "Tell me how to use LoRA in 100 words or less",
            None,
        ),
    ]
    for prompt in user_prompts:
        response = rag_agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt[0],
                }
            ],
            documents=prompt[1],
            session_id=session_id,
            stream=False,
        )

    assert "lora" in response.output_message.content.lower()


@pytest.mark.parametrize(
    "client_tools",
    [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],