mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
3: rename core routers with vector_store
This commit is contained in:
parent
18ff28b6f0
commit
3d7b463a80
11 changed files with 83 additions and 83 deletions
|
|
@ -121,7 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
||||||
|
|
||||||
models = "models"
|
models = "models"
|
||||||
shields = "shields"
|
shields = "shields"
|
||||||
vector_dbs = "vector_dbs" # only used for routing
|
vector_stores = "vector_stores" # only used for routing table
|
||||||
datasets = "datasets"
|
datasets = "datasets"
|
||||||
scoring_functions = "scoring_functions"
|
scoring_functions = "scoring_functions"
|
||||||
benchmarks = "benchmarks"
|
benchmarks = "benchmarks"
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class AccessRule(BaseModel):
|
||||||
A rule defines a list of actions either to permit or to forbid. It may specify a
|
A rule defines a list of actions either to permit or to forbid. It may specify a
|
||||||
principal or a resource that must match for the rule to take effect. The resource
|
principal or a resource that must match for the rule to take effect. The resource
|
||||||
to match should be specified in the form of a type qualified identifier, e.g.
|
to match should be specified in the form of a type qualified identifier, e.g.
|
||||||
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
|
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
|
||||||
e.g. model::*. If the principal or resource are not specified, they will match all
|
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||||
requests.
|
requests.
|
||||||
|
|
||||||
|
|
@ -79,9 +79,9 @@ class AccessRule(BaseModel):
|
||||||
description: any user has read access to any resource created by a member of their team
|
description: any user has read access to any resource created by a member of their team
|
||||||
- forbid:
|
- forbid:
|
||||||
actions: [create, read, delete]
|
actions: [create, read, delete]
|
||||||
resource: vector_db::*
|
resource: vector_store::*
|
||||||
unless: user with admin in roles
|
unless: user with admin in roles
|
||||||
description: only user with admin role can use vector_db resources
|
description: only user with admin role can use vector_store resources
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||||
from llama_stack.apis.shields import Shield, ShieldInput
|
from llama_stack.apis.shields import Shield, ShieldInput
|
||||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorStore, VectorStoreInput
|
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.core.access_control.datatypes import AccessRule
|
from llama_stack.core.access_control.datatypes import AccessRule
|
||||||
from llama_stack.core.storage.datatypes import (
|
from llama_stack.core.storage.datatypes import (
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
||||||
router_api=Api.tool_runtime,
|
router_api=Api.tool_runtime,
|
||||||
),
|
),
|
||||||
AutoRoutedApiInfo(
|
AutoRoutedApiInfo(
|
||||||
routing_table_api=Api.vector_dbs,
|
routing_table_api=Api.vector_stores,
|
||||||
router_api=Api.vector_io,
|
router_api=Api.vector_io,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDBs
|
from llama_stack.apis.vector_stores import VectorStores
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||||
from llama_stack.core.client import get_client_impl
|
from llama_stack.core.client import get_client_impl
|
||||||
|
|
@ -82,7 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
||||||
Api.inspect: Inspect,
|
Api.inspect: Inspect,
|
||||||
Api.batches: Batches,
|
Api.batches: Batches,
|
||||||
Api.vector_io: VectorIO,
|
Api.vector_io: VectorIO,
|
||||||
Api.vector_dbs: VectorDBs,
|
Api.vector_stores: VectorStores,
|
||||||
Api.models: Models,
|
Api.models: Models,
|
||||||
Api.safety: Safety,
|
Api.safety: Safety,
|
||||||
Api.shields: Shields,
|
Api.shields: Shields,
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ async def get_routing_table_impl(
|
||||||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||||
from ..routing_tables.shields import ShieldsRoutingTable
|
from ..routing_tables.shields import ShieldsRoutingTable
|
||||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
from ..routing_tables.vector_stores import VectorStoresRoutingTable
|
||||||
|
|
||||||
api_to_tables = {
|
api_to_tables = {
|
||||||
"models": ModelsRoutingTable,
|
"models": ModelsRoutingTable,
|
||||||
|
|
@ -38,7 +38,7 @@ async def get_routing_table_impl(
|
||||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||||
"benchmarks": BenchmarksRoutingTable,
|
"benchmarks": BenchmarksRoutingTable,
|
||||||
"tool_groups": ToolGroupsRoutingTable,
|
"tool_groups": ToolGroupsRoutingTable,
|
||||||
"vector_dbs": VectorDBsRoutingTable,
|
"vector_stores": VectorStoresRoutingTable,
|
||||||
}
|
}
|
||||||
|
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
|
|
|
||||||
|
|
@ -37,24 +37,24 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
vector_db_ids: list[str],
|
vector_store_ids: list[str],
|
||||||
query_config: RAGQueryConfig | None = None,
|
query_config: RAGQueryConfig | None = None,
|
||||||
) -> RAGQueryResult:
|
) -> RAGQueryResult:
|
||||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
|
||||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||||
return await provider.query(content, vector_db_ids, query_config)
|
return await provider.query(content, vector_store_ids, query_config)
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
self,
|
self,
|
||||||
documents: list[RAGDocument],
|
documents: list[RAGDocument],
|
||||||
vector_db_id: str,
|
vector_store_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||||
)
|
)
|
||||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -146,22 +146,22 @@ class VectorIORouter(VectorIO):
|
||||||
else:
|
else:
|
||||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
vector_store_id = f"vs_{uuid.uuid4()}"
|
||||||
registered_vector_db = await self.routing_table.register_vector_db(
|
registered_vector_store = await self.routing_table.register_vector_store(
|
||||||
vector_db_id=vector_db_id,
|
vector_store_id=vector_store_id,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_vector_db_id=vector_db_id,
|
provider_vector_store_id=vector_store_id,
|
||||||
vector_db_name=params.name,
|
vector_store_name=params.name,
|
||||||
)
|
)
|
||||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
|
||||||
|
|
||||||
# Update model_extra with registered values so provider uses the already-registered vector_db
|
# Update model_extra with registered values so provider uses the already-registered vector_store
|
||||||
if params.model_extra is None:
|
if params.model_extra is None:
|
||||||
params.model_extra = {}
|
params.model_extra = {}
|
||||||
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
|
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
|
||||||
params.model_extra["provider_id"] = registered_vector_db.provider_id
|
params.model_extra["provider_id"] = registered_vector_store.provider_id
|
||||||
if embedding_model is not None:
|
if embedding_model is not None:
|
||||||
params.model_extra["embedding_model"] = embedding_model
|
params.model_extra["embedding_model"] = embedding_model
|
||||||
if embedding_dimension is not None:
|
if embedding_dimension is not None:
|
||||||
|
|
@ -179,15 +179,15 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||||
# Route to default provider for now - could aggregate from all providers in the future
|
# Route to default provider for now - could aggregate from all providers in the future
|
||||||
# call retrieve on each vector dbs to get list of vector stores
|
# call retrieve on each vector store to get the list of vector stores
|
||||||
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
vector_stores = await self.routing_table.get_all_with_type("vector_store")
|
||||||
all_stores = []
|
all_stores = []
|
||||||
for vector_db in vector_dbs:
|
for vector_store in vector_stores:
|
||||||
try:
|
try:
|
||||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
|
||||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
|
||||||
all_stores.append(vector_store)
|
all_stores.append(vector_store)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sort by created_at
|
# Sort by created_at
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
return await p.register_shield(obj)
|
return await p.register_shield(obj)
|
||||||
elif api == Api.vector_io:
|
elif api == Api.vector_io:
|
||||||
return await p.register_vector_db(obj)
|
return await p.register_vector_store(obj)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.register_dataset(obj)
|
return await p.register_dataset(obj)
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
|
|
@ -57,7 +57,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.vector_io:
|
if api == Api.vector_io:
|
||||||
return await p.unregister_vector_db(obj.identifier)
|
return await p.unregister_vector_store(obj.identifier)
|
||||||
elif api == Api.inference:
|
elif api == Api.inference:
|
||||||
return await p.unregister_model(obj.identifier)
|
return await p.unregister_model(obj.identifier)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
|
|
@ -108,7 +108,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
elif api == Api.vector_io:
|
elif api == Api.vector_io:
|
||||||
p.vector_db_store = self
|
p.vector_store_store = self
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
|
|
@ -134,15 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
from .scoring_functions import ScoringFunctionsRoutingTable
|
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||||
from .shields import ShieldsRoutingTable
|
from .shields import ShieldsRoutingTable
|
||||||
from .toolgroups import ToolGroupsRoutingTable
|
from .toolgroups import ToolGroupsRoutingTable
|
||||||
from .vector_dbs import VectorDBsRoutingTable
|
from .vector_stores import VectorStoresRoutingTable
|
||||||
|
|
||||||
def apiname_object():
|
def apiname_object():
|
||||||
if isinstance(self, ModelsRoutingTable):
|
if isinstance(self, ModelsRoutingTable):
|
||||||
return ("Inference", "model")
|
return ("Inference", "model")
|
||||||
elif isinstance(self, ShieldsRoutingTable):
|
elif isinstance(self, ShieldsRoutingTable):
|
||||||
return ("Safety", "shield")
|
return ("Safety", "shield")
|
||||||
elif isinstance(self, VectorDBsRoutingTable):
|
elif isinstance(self, VectorStoresRoutingTable):
|
||||||
return ("VectorIO", "vector_db")
|
return ("VectorIO", "vector_store")
|
||||||
elif isinstance(self, DatasetsRoutingTable):
|
elif isinstance(self, DatasetsRoutingTable):
|
||||||
return ("DatasetIO", "dataset")
|
return ("DatasetIO", "dataset")
|
||||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
|
|
||||||
# Removed VectorDBs import to avoid exposing public API
|
# Removed VectorStores import to avoid exposing public API
|
||||||
from llama_stack.apis.vector_io.vector_io import (
|
from llama_stack.apis.vector_io.vector_io import (
|
||||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||||
SearchRankingOptions,
|
SearchRankingOptions,
|
||||||
|
|
@ -26,7 +26,7 @@ from llama_stack.apis.vector_io.vector_io import (
|
||||||
VectorStoreSearchResponsePage,
|
VectorStoreSearchResponsePage,
|
||||||
)
|
)
|
||||||
from llama_stack.core.datatypes import (
|
from llama_stack.core.datatypes import (
|
||||||
VectorDBWithOwner,
|
VectorStoreWithOwner,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
|
@ -35,23 +35,23 @@ from .common import CommonRoutingTableImpl, lookup_model
|
||||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||||
"""Internal routing table for vector_db operations.
|
"""Internal routing table for vector_store operations.
|
||||||
|
|
||||||
Does not inherit from VectorDBs to avoid exposing public API endpoints.
|
Does not inherit from VectorStores to avoid exposing public API endpoints.
|
||||||
Only provides internal routing functionality for VectorIORouter.
|
Only provides internal routing functionality for VectorIORouter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Internal methods only - no public API exposure
|
# Internal methods only - no public API exposure
|
||||||
|
|
||||||
async def register_vector_db(
|
async def register_vector_store(
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_store_id: str,
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
embedding_dimension: int | None = 384,
|
embedding_dimension: int | None = 384,
|
||||||
provider_id: str | None = None,
|
provider_id: str | None = None,
|
||||||
provider_vector_db_id: str | None = None,
|
provider_vector_store_id: str | None = None,
|
||||||
vector_db_name: str | None = None,
|
vector_store_name: str | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
if len(self.impls_by_provider_id) > 0:
|
if len(self.impls_by_provider_id) > 0:
|
||||||
|
|
@ -78,41 +78,41 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
|
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
|
||||||
) from None
|
) from None
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
"VectorStore is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||||
)
|
)
|
||||||
request = OpenAICreateVectorStoreRequestWithExtraBody(
|
request = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||||
name=vector_db_name or vector_db_id,
|
name=vector_store_name or vector_store_id,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
embedding_dimension=model.metadata["embedding_dimension"],
|
embedding_dimension=model.metadata["embedding_dimension"],
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_vector_db_id=provider_vector_db_id,
|
provider_vector_store_id=provider_vector_store_id,
|
||||||
)
|
)
|
||||||
vector_store = await provider.openai_create_vector_store(request)
|
vector_store = await provider.openai_create_vector_store(request)
|
||||||
|
|
||||||
vector_store_id = vector_store.id
|
vector_store_id = vector_store.id
|
||||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
actual_provider_vector_store_id = provider_vector_store_id or vector_store_id
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
f"Ignoring vector_store_id {vector_store_id} and using vector_store_id {vector_store_id} instead. Setting VectorStore {vector_store_id} to VectorStore.vector_store_name"
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_db_data = {
|
vector_store_data = {
|
||||||
"identifier": vector_store_id,
|
"identifier": vector_store_id,
|
||||||
"type": ResourceType.vector_db.value,
|
"type": ResourceType.vector_store.value,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_resource_id": actual_provider_vector_db_id,
|
"provider_resource_id": actual_provider_vector_store_id,
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
"vector_db_name": vector_store.name,
|
"vector_store_name": vector_store.name,
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
vector_store = TypeAdapter(VectorStoreWithOwner).validate_python(vector_store_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_store)
|
||||||
return vector_db
|
return vector_store
|
||||||
|
|
||||||
async def openai_retrieve_vector_store(
|
async def openai_retrieve_vector_store(
|
||||||
self,
|
self,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreObject:
|
) -> VectorStoreObject:
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||||
|
|
||||||
|
|
@ -123,7 +123,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
expires_after: dict[str, Any] | None = None,
|
expires_after: dict[str, Any] | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> VectorStoreObject:
|
) -> VectorStoreObject:
|
||||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_update_vector_store(
|
return await provider.openai_update_vector_store(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -136,18 +136,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
self,
|
self,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreDeleteResponse:
|
) -> VectorStoreDeleteResponse:
|
||||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||||
await self.unregister_vector_db(vector_store_id)
|
await self.unregister_vector_store(vector_store_id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def unregister_vector_db(self, vector_store_id: str) -> None:
|
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||||
"""Remove the vector store from the routing table registry."""
|
"""Remove the vector store from the routing table registry."""
|
||||||
try:
|
try:
|
||||||
vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
|
vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
|
||||||
if vector_db_obj:
|
if vector_store_obj:
|
||||||
await self.unregister_object(vector_db_obj)
|
await self.unregister_object(vector_store_obj)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the error but don't fail the operation
|
# Log the error but don't fail the operation
|
||||||
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
|
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
|
||||||
|
|
@ -162,7 +162,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
rewrite_query: bool | None = False,
|
rewrite_query: bool | None = False,
|
||||||
search_mode: str | None = "vector",
|
search_mode: str | None = "vector",
|
||||||
) -> VectorStoreSearchResponsePage:
|
) -> VectorStoreSearchResponsePage:
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_search_vector_store(
|
return await provider.openai_search_vector_store(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -181,7 +181,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
attributes: dict[str, Any] | None = None,
|
attributes: dict[str, Any] | None = None,
|
||||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_attach_file_to_vector_store(
|
return await provider.openai_attach_file_to_vector_store(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -199,7 +199,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
before: str | None = None,
|
before: str | None = None,
|
||||||
filter: VectorStoreFileStatus | None = None,
|
filter: VectorStoreFileStatus | None = None,
|
||||||
) -> list[VectorStoreFileObject]:
|
) -> list[VectorStoreFileObject]:
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_list_files_in_vector_store(
|
return await provider.openai_list_files_in_vector_store(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -215,7 +215,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_retrieve_vector_store_file(
|
return await provider.openai_retrieve_vector_store_file(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -227,7 +227,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileContentsResponse:
|
) -> VectorStoreFileContentsResponse:
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_retrieve_vector_store_file_contents(
|
return await provider.openai_retrieve_vector_store_file_contents(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -240,7 +240,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
attributes: dict[str, Any],
|
attributes: dict[str, Any],
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_update_vector_store_file(
|
return await provider.openai_update_vector_store_file(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -253,7 +253,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileDeleteResponse:
|
) -> VectorStoreFileDeleteResponse:
|
||||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_delete_vector_store_file(
|
return await provider.openai_delete_vector_store_file(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -267,7 +267,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
attributes: dict[str, Any] | None = None,
|
attributes: dict[str, Any] | None = None,
|
||||||
chunking_strategy: Any | None = None,
|
chunking_strategy: Any | None = None,
|
||||||
):
|
):
|
||||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_create_vector_store_file_batch(
|
return await provider.openai_create_vector_store_file_batch(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -281,7 +281,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
):
|
):
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_retrieve_vector_store_file_batch(
|
return await provider.openai_retrieve_vector_store_file_batch(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
|
|
@ -298,7 +298,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
limit: int | None = 20,
|
limit: int | None = 20,
|
||||||
order: str | None = "desc",
|
order: str | None = "desc",
|
||||||
):
|
):
|
||||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
return await provider.openai_list_files_in_vector_store_file_batch(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
|
|
@ -315,7 +315,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
):
|
):
|
||||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_cancel_vector_store_file_batch(
|
return await provider.openai_cancel_vector_store_file_batch(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ def tool_chat_page():
|
||||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||||
selected_vector_dbs = []
|
selected_vector_stores = []
|
||||||
|
|
||||||
def reset_agent():
|
def reset_agent():
|
||||||
st.session_state.clear()
|
st.session_state.clear()
|
||||||
|
|
@ -55,13 +55,13 @@ def tool_chat_page():
|
||||||
)
|
)
|
||||||
|
|
||||||
if "builtin::rag" in toolgroup_selection:
|
if "builtin::rag" in toolgroup_selection:
|
||||||
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
|
vector_stores = llama_stack_api.client.vector_stores.list() or []
|
||||||
if not vector_dbs:
|
if not vector_stores:
|
||||||
st.info("No vector databases available for selection.")
|
st.info("No vector databases available for selection.")
|
||||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
vector_stores = [vector_store.identifier for vector_store in vector_stores]
|
||||||
selected_vector_dbs = st.multiselect(
|
selected_vector_stores = st.multiselect(
|
||||||
label="Select Document Collections to use in RAG queries",
|
label="Select Document Collections to use in RAG queries",
|
||||||
options=vector_dbs,
|
options=vector_stores,
|
||||||
on_change=reset_agent,
|
on_change=reset_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -119,7 +119,7 @@ def tool_chat_page():
|
||||||
tool_dict = dict(
|
tool_dict = dict(
|
||||||
name="builtin::rag",
|
name="builtin::rag",
|
||||||
args={
|
args={
|
||||||
"vector_db_ids": list(selected_vector_dbs),
|
"vector_store_ids": list(selected_vector_stores),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
toolgroup_selection[i] = tool_dict
|
toolgroup_selection[i] = tool_dict
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue