chore!: BREAKING CHANGE removing VectorDB APIs (#3774)

# What does this PR do?
Removes the VectorDBs API from the API surface and from our tests.

Moves the affected tests over to the Vector Stores API.
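
For downstream users, the practical change is that `client.vector_dbs.register(...)` and the other `/vector-dbs` routes go away and the same functionality is reached through the Vector Stores API. A minimal migration sketch, assuming the OpenAI-compatible `vector_stores` surface of `llama-stack-client` (exact parameters, e.g. how the embedding model is selected, may differ between client versions):

```python
# Minimal migration sketch (illustrative, not part of this PR). Assumes the
# OpenAI-compatible Vector Stores surface of llama-stack-client; exact
# parameters (e.g. embedding model selection) may differ by client version.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Before (removed in this PR):
# client.vector_dbs.register(
#     vector_db_id="rag_vector_db",
#     embedding_model="all-MiniLM-L6-v2",
#     embedding_dimension=384,
# )

# After: create, list, and delete vector stores instead.
vector_store = client.vector_stores.create(name="rag_vector_db")
stores = client.vector_stores.list()
client.vector_stores.delete(vector_store_id=vector_store.id)
```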

## Test Plan

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Francisco Arceo, 2025-10-11 17:07:08 -04:00, committed by GitHub
parent 06e4cd8e02
commit a165b8b5bb
111 changed files with 60412 additions and 2765 deletions

@@ -96,7 +96,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar vector_dbs: Vector database management
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
@@ -122,7 +121,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
models = "models"
shields = "shields"
vector_dbs = "vector_dbs"
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"

@@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type
@json_schema_type
@@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
"""
data: list[VectorDB]
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...
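
Nothing behind these routes is lost: the removed `VectorDBsRoutingTable` further down in this diff delegates every call to the VectorIO provider's OpenAI-compatible vector-store methods, which stay in place. A rough mapping from the removed routes to those methods, with method names taken from that routing table (the dict itself is purely illustrative):

```python
# Illustrative mapping (documentation only, not part of the codebase) from the
# removed /vector-dbs routes to the OpenAI-compatible VectorIO methods they
# delegated to, as seen in the removed VectorDBsRoutingTable below.
REMOVED_ROUTE_TO_REPLACEMENT = {
    "POST /vector-dbs (register_vector_db)": "openai_create_vector_store",
    "GET /vector-dbs/{vector_db_id} (get_vector_db)": "openai_retrieve_vector_store",
    "DELETE /vector-dbs/{vector_db_id} (unregister_vector_db)": "openai_delete_vector_store",
}
```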

@@ -47,10 +47,6 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs,
router_api=Api.vector_io,
),
AutoRoutedApiInfo(
routing_table_api=Api.datasets,
router_api=Api.datasetio,

@@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
@@ -55,7 +54,6 @@ from llama_stack.providers.datatypes import (
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
@@ -81,7 +79,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
@@ -125,7 +122,6 @@ def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (

@@ -26,10 +26,8 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,

@@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):

@@ -1,309 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

@@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
@@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
Agents,
Safety,
@@ -83,7 +81,6 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",

@@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
"Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
@@ -37,8 +35,6 @@ def resources_page():
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":

@@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")

@@ -1,301 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
vector_io_provider = None
for x in providers:
if x.api == "vector_io":
vector_io_provider = x.provider_id
llama_stack_api.client.vector_dbs.register(
vector_db_id=vector_db_name, # Use the user-provided name
embedding_dimension=384,
embedding_model="all-MiniLM-L6-v2",
provider_id=vector_io_provider,
)
# insert documents using the custom vector db name
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
],
)
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

@@ -32,7 +32,6 @@ def available_providers() -> list[ProviderSpec]:
Api.inference,
Api.safety,
Api.vector_io,
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,