remove vector_db

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-10 15:53:02 -04:00
parent aee4afb327
commit 27c2b2dd57
16 changed files with 4 additions and 2596 deletions

View file

@ -2987,165 +2987,6 @@
"deprecated": false
}
},
"/v1/vector-dbs": {
"get": {
"responses": {
"200": {
"description": "A ListVectorDBsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListVectorDBsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "List all vector databases.",
"description": "List all vector databases.",
"parameters": [],
"deprecated": false
},
"post": {
"responses": {
"200": {
"description": "A VectorDB.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/VectorDB"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Register a vector database.",
"description": "Register a vector database.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RegisterVectorDbRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/vector-dbs/{vector_db_id}": {
"get": {
"responses": {
"200": {
"description": "A VectorDB.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/VectorDB"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Get a vector database by its identifier.",
"description": "Get a vector database by its identifier.",
"parameters": [
{
"name": "vector_db_id",
"in": "path",
"description": "The identifier of the vector database to get.",
"required": true,
"schema": {
"type": "string"
}
}
],
"deprecated": false
},
"delete": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Unregister a vector database.",
"description": "Unregister a vector database.",
"parameters": [
{
"name": "vector_db_id",
"in": "path",
"description": "The identifier of the vector database to unregister.",
"required": true,
"schema": {
"type": "string"
}
}
],
"deprecated": false
}
},
"/v1/vector-io/insert": {
"post": {
"responses": {
@ -11791,111 +11632,6 @@
],
"title": "RegisterToolGroupRequest"
},
"VectorDB": {
"type": "object",
"properties": {
"identifier": {
"type": "string"
},
"provider_resource_id": {
"type": "string"
},
"provider_id": {
"type": "string"
},
"type": {
"type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group",
"prompt"
],
"const": "vector_db",
"default": "vector_db",
"description": "Type of resource, always 'vector_db' for vector databases"
},
"embedding_model": {
"type": "string",
"description": "Name of the embedding model to use for vector generation"
},
"embedding_dimension": {
"type": "integer",
"description": "Dimension of the embedding vectors"
},
"vector_db_name": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"identifier",
"provider_id",
"type",
"embedding_model",
"embedding_dimension"
],
"title": "VectorDB",
"description": "Vector database resource for storing and querying vector embeddings."
},
"ListVectorDBsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/VectorDB"
},
"description": "List of vector databases"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListVectorDBsResponse",
"description": "Response from listing vector databases."
},
"RegisterVectorDbRequest": {
"type": "object",
"properties": {
"vector_db_id": {
"type": "string",
"description": "The identifier of the vector database to register."
},
"embedding_model": {
"type": "string",
"description": "The embedding model to use."
},
"embedding_dimension": {
"type": "integer",
"description": "The dimension of the embedding model."
},
"provider_id": {
"type": "string",
"description": "The identifier of the provider."
},
"vector_db_name": {
"type": "string",
"description": "The name of the vector database."
},
"provider_vector_db_id": {
"type": "string",
"description": "The identifier of the vector database in the provider."
}
},
"additionalProperties": false,
"required": [
"vector_db_id",
"embedding_model"
],
"title": "RegisterVectorDbRequest"
},
"Chunk": {
"type": "object",
"properties": {
@ -13371,10 +13107,6 @@
"name": "ToolRuntime",
"description": ""
},
{
"name": "VectorDBs",
"description": ""
},
{
"name": "VectorIO",
"description": ""
@ -13400,7 +13132,6 @@
"Telemetry",
"ToolGroups",
"ToolRuntime",
"VectorDBs",
"VectorIO"
]
}

View file

@ -2275,120 +2275,6 @@ paths:
schema:
type: string
deprecated: false
/v1/vector-dbs:
get:
responses:
'200':
description: A ListVectorDBsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListVectorDBsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: List all vector databases.
description: List all vector databases.
parameters: []
deprecated: false
post:
responses:
'200':
description: A VectorDB.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Register a vector database.
description: Register a vector database.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterVectorDbRequest'
required: true
deprecated: false
/v1/vector-dbs/{vector_db_id}:
get:
responses:
'200':
description: A VectorDB.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Get a vector database by its identifier.
description: Get a vector database by its identifier.
parameters:
- name: vector_db_id
in: path
description: >-
The identifier of the vector database to get.
required: true
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Unregister a vector database.
description: Unregister a vector database.
parameters:
- name: vector_db_id
in: path
description: >-
The identifier of the vector database to unregister.
required: true
schema:
type: string
deprecated: false
/v1/vector-io/insert:
post:
responses:
@ -8910,91 +8796,6 @@ components:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
VectorDB:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
- prompt
const: vector_db
default: vector_db
description: >-
Type of resource, always 'vector_db' for vector databases
embedding_model:
type: string
description: >-
Name of the embedding model to use for vector generation
embedding_dimension:
type: integer
description: Dimension of the embedding vectors
vector_db_name:
type: string
additionalProperties: false
required:
- identifier
- provider_id
- type
- embedding_model
- embedding_dimension
title: VectorDB
description: >-
Vector database resource for storing and querying vector embeddings.
ListVectorDBsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/VectorDB'
description: List of vector databases
additionalProperties: false
required:
- data
title: ListVectorDBsResponse
description: Response from listing vector databases.
RegisterVectorDbRequest:
type: object
properties:
vector_db_id:
type: string
description: >-
The identifier of the vector database to register.
embedding_model:
type: string
description: The embedding model to use.
embedding_dimension:
type: integer
description: The dimension of the embedding model.
provider_id:
type: string
description: The identifier of the provider.
vector_db_name:
type: string
description: The name of the vector database.
provider_vector_db_id:
type: string
description: >-
The identifier of the vector database in the provider.
additionalProperties: false
required:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
Chunk:
type: object
properties:
@ -10164,8 +9965,6 @@ tags:
description: ''
- name: ToolRuntime
description: ''
- name: VectorDBs
description: ''
- name: VectorIO
description: ''
x-tagGroups:
@ -10187,5 +9986,4 @@ x-tagGroups:
- Telemetry
- ToolGroups
- ToolRuntime
- VectorDBs
- VectorIO

View file

@ -2987,165 +2987,6 @@
"deprecated": false
}
},
"/v1/vector-dbs": {
"get": {
"responses": {
"200": {
"description": "A ListVectorDBsResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListVectorDBsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "List all vector databases.",
"description": "List all vector databases.",
"parameters": [],
"deprecated": false
},
"post": {
"responses": {
"200": {
"description": "A VectorDB.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/VectorDB"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Register a vector database.",
"description": "Register a vector database.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RegisterVectorDbRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/vector-dbs/{vector_db_id}": {
"get": {
"responses": {
"200": {
"description": "A VectorDB.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/VectorDB"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Get a vector database by its identifier.",
"description": "Get a vector database by its identifier.",
"parameters": [
{
"name": "vector_db_id",
"in": "path",
"description": "The identifier of the vector database to get.",
"required": true,
"schema": {
"type": "string"
}
}
],
"deprecated": false
},
"delete": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"VectorDBs"
],
"summary": "Unregister a vector database.",
"description": "Unregister a vector database.",
"parameters": [
{
"name": "vector_db_id",
"in": "path",
"description": "The identifier of the vector database to unregister.",
"required": true,
"schema": {
"type": "string"
}
}
],
"deprecated": false
}
},
"/v1/vector-io/insert": {
"post": {
"responses": {
@ -13800,111 +13641,6 @@
],
"title": "RegisterToolGroupRequest"
},
"VectorDB": {
"type": "object",
"properties": {
"identifier": {
"type": "string"
},
"provider_resource_id": {
"type": "string"
},
"provider_id": {
"type": "string"
},
"type": {
"type": "string",
"enum": [
"model",
"shield",
"vector_db",
"dataset",
"scoring_function",
"benchmark",
"tool",
"tool_group",
"prompt"
],
"const": "vector_db",
"default": "vector_db",
"description": "Type of resource, always 'vector_db' for vector databases"
},
"embedding_model": {
"type": "string",
"description": "Name of the embedding model to use for vector generation"
},
"embedding_dimension": {
"type": "integer",
"description": "Dimension of the embedding vectors"
},
"vector_db_name": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"identifier",
"provider_id",
"type",
"embedding_model",
"embedding_dimension"
],
"title": "VectorDB",
"description": "Vector database resource for storing and querying vector embeddings."
},
"ListVectorDBsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/VectorDB"
},
"description": "List of vector databases"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListVectorDBsResponse",
"description": "Response from listing vector databases."
},
"RegisterVectorDbRequest": {
"type": "object",
"properties": {
"vector_db_id": {
"type": "string",
"description": "The identifier of the vector database to register."
},
"embedding_model": {
"type": "string",
"description": "The embedding model to use."
},
"embedding_dimension": {
"type": "integer",
"description": "The dimension of the embedding model."
},
"provider_id": {
"type": "string",
"description": "The identifier of the provider."
},
"vector_db_name": {
"type": "string",
"description": "The name of the vector database."
},
"provider_vector_db_id": {
"type": "string",
"description": "The identifier of the vector database in the provider."
}
},
"additionalProperties": false,
"required": [
"vector_db_id",
"embedding_model"
],
"title": "RegisterVectorDbRequest"
},
"Chunk": {
"type": "object",
"properties": {
@ -18948,10 +18684,6 @@
"name": "ToolRuntime",
"description": ""
},
{
"name": "VectorDBs",
"description": ""
},
{
"name": "VectorIO",
"description": ""
@ -18982,7 +18714,6 @@
"Telemetry",
"ToolGroups",
"ToolRuntime",
"VectorDBs",
"VectorIO"
]
}

View file

@ -2278,120 +2278,6 @@ paths:
schema:
type: string
deprecated: false
/v1/vector-dbs:
get:
responses:
'200':
description: A ListVectorDBsResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListVectorDBsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: List all vector databases.
description: List all vector databases.
parameters: []
deprecated: false
post:
responses:
'200':
description: A VectorDB.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Register a vector database.
description: Register a vector database.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RegisterVectorDbRequest'
required: true
deprecated: false
/v1/vector-dbs/{vector_db_id}:
get:
responses:
'200':
description: A VectorDB.
content:
application/json:
schema:
$ref: '#/components/schemas/VectorDB'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Get a vector database by its identifier.
description: Get a vector database by its identifier.
parameters:
- name: vector_db_id
in: path
description: >-
The identifier of the vector database to get.
required: true
schema:
type: string
deprecated: false
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- VectorDBs
summary: Unregister a vector database.
description: Unregister a vector database.
parameters:
- name: vector_db_id
in: path
description: >-
The identifier of the vector database to unregister.
required: true
schema:
type: string
deprecated: false
/v1/vector-io/insert:
post:
responses:
@ -10355,91 +10241,6 @@ components:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
VectorDB:
type: object
properties:
identifier:
type: string
provider_resource_id:
type: string
provider_id:
type: string
type:
type: string
enum:
- model
- shield
- vector_db
- dataset
- scoring_function
- benchmark
- tool
- tool_group
- prompt
const: vector_db
default: vector_db
description: >-
Type of resource, always 'vector_db' for vector databases
embedding_model:
type: string
description: >-
Name of the embedding model to use for vector generation
embedding_dimension:
type: integer
description: Dimension of the embedding vectors
vector_db_name:
type: string
additionalProperties: false
required:
- identifier
- provider_id
- type
- embedding_model
- embedding_dimension
title: VectorDB
description: >-
Vector database resource for storing and querying vector embeddings.
ListVectorDBsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/VectorDB'
description: List of vector databases
additionalProperties: false
required:
- data
title: ListVectorDBsResponse
description: Response from listing vector databases.
RegisterVectorDbRequest:
type: object
properties:
vector_db_id:
type: string
description: >-
The identifier of the vector database to register.
embedding_model:
type: string
description: The embedding model to use.
embedding_dimension:
type: integer
description: The dimension of the embedding model.
provider_id:
type: string
description: The identifier of the provider.
vector_db_name:
type: string
description: The name of the vector database.
provider_vector_db_id:
type: string
description: >-
The identifier of the vector database in the provider.
additionalProperties: false
required:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
Chunk:
type: object
properties:
@ -14212,8 +14013,6 @@ tags:
description: ''
- name: ToolRuntime
description: ''
- name: VectorDBs
description: ''
- name: VectorIO
description: ''
x-tagGroups:
@ -14240,5 +14039,4 @@ x-tagGroups:
- Telemetry
- ToolGroups
- ToolRuntime
- VectorDBs
- VectorIO

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type
@json_schema_type
@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
"""
data: list[VectorDB]
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
@ -81,7 +80,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
@ -125,7 +123,7 @@ def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.vector_io: (VectorDBsProtocolPrivate,),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (

View file

@ -26,10 +26,8 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,

View file

@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):

View file

@ -1,306 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise VectorStoreNotFoundError(vector_db_id)
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
VectorDBs,
Inference,
Agents,
Safety,
@ -83,7 +81,6 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",

View file

@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
"Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
@ -37,8 +35,6 @@ def resources_page():
)
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import streamlit as st
from llama_stack.core.ui.modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")

View file

@ -1,301 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.core.ui.modules.api import llama_stack_api
from llama_stack.core.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
)
# Process uploaded files
if uploaded_files:
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
for i, uploaded_file in enumerate(uploaded_files)
]
providers = llama_stack_api.client.providers.list()
vector_io_provider = None
for x in providers:
if x.api == "vector_io":
vector_io_provider = x.provider_id
llama_stack_api.client.vector_dbs.register(
vector_db_id=vector_db_name, # Use the user-provided name
embedding_dimension=384,
embedding_model="all-MiniLM-L6-v2",
provider_id=vector_io_provider,
)
# insert documents using the custom vector db name
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
"Top P",
min_value=0.0,
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
"type": "top_p",
"temperature": temperature,
"top_p": top_p,
}
else:
strategy = {"type": "greedy"}
@st.cache_resource
def create_agent():
return Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
],
)
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -8,7 +8,6 @@ from typing import Any
from uuid import uuid4
import pytest
import requests
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
@ -443,118 +442,6 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
return [step for step in response.steps if step.step_type == "tool_execution"]
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
# small chunks help to get specific info out of the docs
chunk_size_in_tokens=256,
)
agent_config = {
**agent_config,
"tools": [
dict(
name=rag_tool_name,
args={
"vector_db_ids": [vector_db_id],
},
)
],
}
rag_agent = Agent(llama_stack_client, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
]
for prompt, expected_kw in user_prompts:
response = rag_agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
stream=False,
)
# rag is called
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
# document ids are present in metadata
assert all(
doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
)
if expected_kw:
assert expected_kw in response.output_message.content.lower()
def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
urls = ["llama3.rst", "lora_finetune.rst"]
documents = [
# passign as url
Document(
document_id="num-0",
content={
"type": "url",
"uri": f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[0]}",
},
mime_type="text/plain",
metadata={},
),
# passing as str
Document(
document_id="num-1",
content=requests.get(
f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[1]}"
).text[:500],
mime_type="text/plain",
metadata={},
),
]
rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
documents,
),
(
"Tell me how to use LoRA in 100 words or less",
None,
),
]
for prompt in user_prompts:
response = rag_agent.create_turn(
messages=[
{
"role": "user",
"content": prompt[0],
}
],
documents=prompt[1],
session_id=session_id,
stream=False,
)
assert "lora" in response.output_message.content.lower()
@pytest.mark.parametrize(
"client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],

View file

@ -1,459 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import BadRequestError
from llama_stack_client.types import Document
@pytest.fixture(scope="function")
def client_with_empty_registry(client_with_models):
def clear_registry():
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
for vector_db_id in vector_dbs:
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
clear_registry()
try:
client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
except Exception:
pass
yield client_with_models
clear_registry()
@pytest.fixture(scope="session")
def sample_documents():
return [
Document(
document_id="test-doc-1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
Document(
document_id="test-doc-2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
Document(
document_id="test-doc-3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
Document(
document_id="test-doc-4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
),
]
def assert_valid_chunk_response(response):
assert len(response.chunks) > 0
assert len(response.scores) > 0
assert len(response.chunks) == len(response.scores)
for chunk in response.chunks:
assert isinstance(chunk.content, str)
def assert_valid_text_response(response):
assert len(response.content) > 0
assert all(isinstance(chunk.text, str) for chunk in response.content)
def test_vector_db_insert_inline_and_query(
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
):
vector_db_name = "test_vector_db"
vector_db = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
vector_db_id = vector_db.identifier
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=sample_documents,
chunk_size_in_tokens=512,
vector_db_id=vector_db_id,
)
# Query with a direct match
query1 = "programming language"
response1 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query=query1,
)
assert_valid_chunk_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Query with semantic similarity
query2 = "AI and brain-inspired computing"
response2 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query=query2,
)
assert_valid_chunk_response(response2)
assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
# Query with limit on number of results (max_chunks=2)
query3 = "computer"
response3 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query=query3,
params={"max_chunks": 2},
)
assert_valid_chunk_response(response3)
assert len(response3.chunks) <= 2
# Query with threshold on similarity score
query4 = "computer"
response4 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query=query4,
params={"score_threshold": 0.01},
)
assert_valid_chunk_response(response4)
assert all(score >= 0.01 for score in response4.scores)
def test_vector_db_insert_from_url_and_query(
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
# list to check memory bank is successfully registered
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
# VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512,
)
# Query for the name of method
response1 = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
query="What's the name of the fine-tunning method used?",
)
assert_valid_chunk_response(response1)
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
# Query for the name of model
response2 = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
query="Which Llama model is mentioned?",
)
assert_valid_chunk_response(response2)
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_openai_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
# different document formats that should work with OpenAI APIs
documents = [
Document(
document_id="text-doc",
content="This is a plain text document about machine learning algorithms.",
metadata={"type": "text", "category": "AI"},
),
Document(
document_id="url-doc",
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
mime_type="text/plain",
metadata={"type": "url", "source": "pytorch"},
),
Document(
document_id="data-url-doc",
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
metadata={"type": "data_url", "encoding": "base64"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= len(documents), (
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
)
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about machine learning and deep learning",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "machine learning" in content_text or "deep learning" in content_text
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_exception_handling"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="valid-doc",
content="This is a valid document that should be processed successfully.",
metadata={"status": "valid"},
),
Document(
document_id="invalid-url-doc",
content="https://nonexistent-domain-12345.com/invalid.txt",
metadata={"status": "invalid_url"},
),
Document(
document_id="another-valid-doc",
content="This is another valid document for testing resilience.",
metadata={"status": "valid"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="valid document",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "valid document" in content_text
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
# VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={"author": "llama", "source": url},
)
for i, url in enumerate(urls)
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512,
)
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
)
assert_valid_text_response(response_with_metadata)
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"include_metadata_in_content": True,
"chunk_template": "Result {index}\nContent: {chunk.content}\n",
},
)
assert_valid_text_response(response_without_metadata)
assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
with pytest.raises((ValueError, BadRequestError)):
client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?",
query_config={
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
},
)
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_query_generation_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="ai-doc",
content="Artificial intelligence and machine learning are transforming technology.",
metadata={"category": "AI"},
),
Document(
document_id="banana-doc",
content="Don't bring a banana to a knife fight.",
metadata={"category": "wisdom"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about AI",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "artificial intelligence" in content_text or "machine learning" in content_text
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_pdf_data_url_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
import base64
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
documents = [
Document(
document_id="test-pdf-data-url",
content=pdf_data_url,
metadata={"type": "pdf", "source": "data_url"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
pdf_file = None
for file in files_list.data:
if file.filename and "test-pdf-data-url" in file.filename:
pdf_file = file
break
assert pdf_file is not None, "PDF file should be found in Files API"
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="sample title",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "sample title" in content_text or "title" in content_text

View file

@ -1,381 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the routing tables vector_dbs
import time
import uuid
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.access_control.datatypes import AccessRule, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data}
for k in vdb_dict:
assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)