diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html
index 7e534f995..238a8afa7 100644
--- a/docs/static/llama-stack-spec.html
+++ b/docs/static/llama-stack-spec.html
@@ -2987,165 +2987,6 @@
"deprecated": false
}
},
- "/v1/vector-dbs": {
- "get": {
- "responses": {
- "200": {
- "description": "A ListVectorDBsResponse.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/ListVectorDBsResponse"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "List all vector databases.",
- "description": "List all vector databases.",
- "parameters": [],
- "deprecated": false
- },
- "post": {
- "responses": {
- "200": {
- "description": "A VectorDB.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/VectorDB"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Register a vector database.",
- "description": "Register a vector database.",
- "parameters": [],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/RegisterVectorDbRequest"
- }
- }
- },
- "required": true
- },
- "deprecated": false
- }
- },
- "/v1/vector-dbs/{vector_db_id}": {
- "get": {
- "responses": {
- "200": {
- "description": "A VectorDB.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/VectorDB"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Get a vector database by its identifier.",
- "description": "Get a vector database by its identifier.",
- "parameters": [
- {
- "name": "vector_db_id",
- "in": "path",
- "description": "The identifier of the vector database to get.",
- "required": true,
- "schema": {
- "type": "string"
- }
- }
- ],
- "deprecated": false
- },
- "delete": {
- "responses": {
- "200": {
- "description": "OK"
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Unregister a vector database.",
- "description": "Unregister a vector database.",
- "parameters": [
- {
- "name": "vector_db_id",
- "in": "path",
- "description": "The identifier of the vector database to unregister.",
- "required": true,
- "schema": {
- "type": "string"
- }
- }
- ],
- "deprecated": false
- }
- },
"/v1/vector-io/insert": {
"post": {
"responses": {
@@ -11791,111 +11632,6 @@
],
"title": "RegisterToolGroupRequest"
},
- "VectorDB": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "enum": [
- "model",
- "shield",
- "vector_db",
- "dataset",
- "scoring_function",
- "benchmark",
- "tool",
- "tool_group",
- "prompt"
- ],
- "const": "vector_db",
- "default": "vector_db",
- "description": "Type of resource, always 'vector_db' for vector databases"
- },
- "embedding_model": {
- "type": "string",
- "description": "Name of the embedding model to use for vector generation"
- },
- "embedding_dimension": {
- "type": "integer",
- "description": "Dimension of the embedding vectors"
- },
- "vector_db_name": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_id",
- "type",
- "embedding_model",
- "embedding_dimension"
- ],
- "title": "VectorDB",
- "description": "Vector database resource for storing and querying vector embeddings."
- },
- "ListVectorDBsResponse": {
- "type": "object",
- "properties": {
- "data": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/VectorDB"
- },
- "description": "List of vector databases"
- }
- },
- "additionalProperties": false,
- "required": [
- "data"
- ],
- "title": "ListVectorDBsResponse",
- "description": "Response from listing vector databases."
- },
- "RegisterVectorDbRequest": {
- "type": "object",
- "properties": {
- "vector_db_id": {
- "type": "string",
- "description": "The identifier of the vector database to register."
- },
- "embedding_model": {
- "type": "string",
- "description": "The embedding model to use."
- },
- "embedding_dimension": {
- "type": "integer",
- "description": "The dimension of the embedding model."
- },
- "provider_id": {
- "type": "string",
- "description": "The identifier of the provider."
- },
- "vector_db_name": {
- "type": "string",
- "description": "The name of the vector database."
- },
- "provider_vector_db_id": {
- "type": "string",
- "description": "The identifier of the vector database in the provider."
- }
- },
- "additionalProperties": false,
- "required": [
- "vector_db_id",
- "embedding_model"
- ],
- "title": "RegisterVectorDbRequest"
- },
"Chunk": {
"type": "object",
"properties": {
@@ -13371,10 +13107,6 @@
"name": "ToolRuntime",
"description": ""
},
- {
- "name": "VectorDBs",
- "description": ""
- },
{
"name": "VectorIO",
"description": ""
@@ -13400,7 +13132,6 @@
"Telemetry",
"ToolGroups",
"ToolRuntime",
- "VectorDBs",
"VectorIO"
]
}
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index bad40c87d..6957afda8 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -2275,120 +2275,6 @@ paths:
schema:
type: string
deprecated: false
- /v1/vector-dbs:
- get:
- responses:
- '200':
- description: A ListVectorDBsResponse.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/ListVectorDBsResponse'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: List all vector databases.
- description: List all vector databases.
- parameters: []
- deprecated: false
- post:
- responses:
- '200':
- description: A VectorDB.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/VectorDB'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Register a vector database.
- description: Register a vector database.
- parameters: []
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/RegisterVectorDbRequest'
- required: true
- deprecated: false
- /v1/vector-dbs/{vector_db_id}:
- get:
- responses:
- '200':
- description: A VectorDB.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/VectorDB'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Get a vector database by its identifier.
- description: Get a vector database by its identifier.
- parameters:
- - name: vector_db_id
- in: path
- description: >-
- The identifier of the vector database to get.
- required: true
- schema:
- type: string
- deprecated: false
- delete:
- responses:
- '200':
- description: OK
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Unregister a vector database.
- description: Unregister a vector database.
- parameters:
- - name: vector_db_id
- in: path
- description: >-
- The identifier of the vector database to unregister.
- required: true
- schema:
- type: string
- deprecated: false
/v1/vector-io/insert:
post:
responses:
@@ -8910,91 +8796,6 @@ components:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
- VectorDB:
- type: object
- properties:
- identifier:
- type: string
- provider_resource_id:
- type: string
- provider_id:
- type: string
- type:
- type: string
- enum:
- - model
- - shield
- - vector_db
- - dataset
- - scoring_function
- - benchmark
- - tool
- - tool_group
- - prompt
- const: vector_db
- default: vector_db
- description: >-
- Type of resource, always 'vector_db' for vector databases
- embedding_model:
- type: string
- description: >-
- Name of the embedding model to use for vector generation
- embedding_dimension:
- type: integer
- description: Dimension of the embedding vectors
- vector_db_name:
- type: string
- additionalProperties: false
- required:
- - identifier
- - provider_id
- - type
- - embedding_model
- - embedding_dimension
- title: VectorDB
- description: >-
- Vector database resource for storing and querying vector embeddings.
- ListVectorDBsResponse:
- type: object
- properties:
- data:
- type: array
- items:
- $ref: '#/components/schemas/VectorDB'
- description: List of vector databases
- additionalProperties: false
- required:
- - data
- title: ListVectorDBsResponse
- description: Response from listing vector databases.
- RegisterVectorDbRequest:
- type: object
- properties:
- vector_db_id:
- type: string
- description: >-
- The identifier of the vector database to register.
- embedding_model:
- type: string
- description: The embedding model to use.
- embedding_dimension:
- type: integer
- description: The dimension of the embedding model.
- provider_id:
- type: string
- description: The identifier of the provider.
- vector_db_name:
- type: string
- description: The name of the vector database.
- provider_vector_db_id:
- type: string
- description: >-
- The identifier of the vector database in the provider.
- additionalProperties: false
- required:
- - vector_db_id
- - embedding_model
- title: RegisterVectorDbRequest
Chunk:
type: object
properties:
@@ -10164,8 +9965,6 @@ tags:
description: ''
- name: ToolRuntime
description: ''
- - name: VectorDBs
- description: ''
- name: VectorIO
description: ''
x-tagGroups:
@@ -10187,5 +9986,4 @@ x-tagGroups:
- Telemetry
- ToolGroups
- ToolRuntime
- - VectorDBs
- VectorIO
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 36c63367c..5c7a23e7b 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -2987,165 +2987,6 @@
"deprecated": false
}
},
- "/v1/vector-dbs": {
- "get": {
- "responses": {
- "200": {
- "description": "A ListVectorDBsResponse.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/ListVectorDBsResponse"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "List all vector databases.",
- "description": "List all vector databases.",
- "parameters": [],
- "deprecated": false
- },
- "post": {
- "responses": {
- "200": {
- "description": "A VectorDB.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/VectorDB"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Register a vector database.",
- "description": "Register a vector database.",
- "parameters": [],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/RegisterVectorDbRequest"
- }
- }
- },
- "required": true
- },
- "deprecated": false
- }
- },
- "/v1/vector-dbs/{vector_db_id}": {
- "get": {
- "responses": {
- "200": {
- "description": "A VectorDB.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/VectorDB"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Get a vector database by its identifier.",
- "description": "Get a vector database by its identifier.",
- "parameters": [
- {
- "name": "vector_db_id",
- "in": "path",
- "description": "The identifier of the vector database to get.",
- "required": true,
- "schema": {
- "type": "string"
- }
- }
- ],
- "deprecated": false
- },
- "delete": {
- "responses": {
- "200": {
- "description": "OK"
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "VectorDBs"
- ],
- "summary": "Unregister a vector database.",
- "description": "Unregister a vector database.",
- "parameters": [
- {
- "name": "vector_db_id",
- "in": "path",
- "description": "The identifier of the vector database to unregister.",
- "required": true,
- "schema": {
- "type": "string"
- }
- }
- ],
- "deprecated": false
- }
- },
"/v1/vector-io/insert": {
"post": {
"responses": {
@@ -13800,111 +13641,6 @@
],
"title": "RegisterToolGroupRequest"
},
- "VectorDB": {
- "type": "object",
- "properties": {
- "identifier": {
- "type": "string"
- },
- "provider_resource_id": {
- "type": "string"
- },
- "provider_id": {
- "type": "string"
- },
- "type": {
- "type": "string",
- "enum": [
- "model",
- "shield",
- "vector_db",
- "dataset",
- "scoring_function",
- "benchmark",
- "tool",
- "tool_group",
- "prompt"
- ],
- "const": "vector_db",
- "default": "vector_db",
- "description": "Type of resource, always 'vector_db' for vector databases"
- },
- "embedding_model": {
- "type": "string",
- "description": "Name of the embedding model to use for vector generation"
- },
- "embedding_dimension": {
- "type": "integer",
- "description": "Dimension of the embedding vectors"
- },
- "vector_db_name": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "identifier",
- "provider_id",
- "type",
- "embedding_model",
- "embedding_dimension"
- ],
- "title": "VectorDB",
- "description": "Vector database resource for storing and querying vector embeddings."
- },
- "ListVectorDBsResponse": {
- "type": "object",
- "properties": {
- "data": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/VectorDB"
- },
- "description": "List of vector databases"
- }
- },
- "additionalProperties": false,
- "required": [
- "data"
- ],
- "title": "ListVectorDBsResponse",
- "description": "Response from listing vector databases."
- },
- "RegisterVectorDbRequest": {
- "type": "object",
- "properties": {
- "vector_db_id": {
- "type": "string",
- "description": "The identifier of the vector database to register."
- },
- "embedding_model": {
- "type": "string",
- "description": "The embedding model to use."
- },
- "embedding_dimension": {
- "type": "integer",
- "description": "The dimension of the embedding model."
- },
- "provider_id": {
- "type": "string",
- "description": "The identifier of the provider."
- },
- "vector_db_name": {
- "type": "string",
- "description": "The name of the vector database."
- },
- "provider_vector_db_id": {
- "type": "string",
- "description": "The identifier of the vector database in the provider."
- }
- },
- "additionalProperties": false,
- "required": [
- "vector_db_id",
- "embedding_model"
- ],
- "title": "RegisterVectorDbRequest"
- },
"Chunk": {
"type": "object",
"properties": {
@@ -18948,10 +18684,6 @@
"name": "ToolRuntime",
"description": ""
},
- {
- "name": "VectorDBs",
- "description": ""
- },
{
"name": "VectorIO",
"description": ""
@@ -18982,7 +18714,6 @@
"Telemetry",
"ToolGroups",
"ToolRuntime",
- "VectorDBs",
"VectorIO"
]
}
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index 4475cc8f0..45a76613e 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -2278,120 +2278,6 @@ paths:
schema:
type: string
deprecated: false
- /v1/vector-dbs:
- get:
- responses:
- '200':
- description: A ListVectorDBsResponse.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/ListVectorDBsResponse'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: List all vector databases.
- description: List all vector databases.
- parameters: []
- deprecated: false
- post:
- responses:
- '200':
- description: A VectorDB.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/VectorDB'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Register a vector database.
- description: Register a vector database.
- parameters: []
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/RegisterVectorDbRequest'
- required: true
- deprecated: false
- /v1/vector-dbs/{vector_db_id}:
- get:
- responses:
- '200':
- description: A VectorDB.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/VectorDB'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Get a vector database by its identifier.
- description: Get a vector database by its identifier.
- parameters:
- - name: vector_db_id
- in: path
- description: >-
- The identifier of the vector database to get.
- required: true
- schema:
- type: string
- deprecated: false
- delete:
- responses:
- '200':
- description: OK
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - VectorDBs
- summary: Unregister a vector database.
- description: Unregister a vector database.
- parameters:
- - name: vector_db_id
- in: path
- description: >-
- The identifier of the vector database to unregister.
- required: true
- schema:
- type: string
- deprecated: false
/v1/vector-io/insert:
post:
responses:
@@ -10355,91 +10241,6 @@ components:
- toolgroup_id
- provider_id
title: RegisterToolGroupRequest
- VectorDB:
- type: object
- properties:
- identifier:
- type: string
- provider_resource_id:
- type: string
- provider_id:
- type: string
- type:
- type: string
- enum:
- - model
- - shield
- - vector_db
- - dataset
- - scoring_function
- - benchmark
- - tool
- - tool_group
- - prompt
- const: vector_db
- default: vector_db
- description: >-
- Type of resource, always 'vector_db' for vector databases
- embedding_model:
- type: string
- description: >-
- Name of the embedding model to use for vector generation
- embedding_dimension:
- type: integer
- description: Dimension of the embedding vectors
- vector_db_name:
- type: string
- additionalProperties: false
- required:
- - identifier
- - provider_id
- - type
- - embedding_model
- - embedding_dimension
- title: VectorDB
- description: >-
- Vector database resource for storing and querying vector embeddings.
- ListVectorDBsResponse:
- type: object
- properties:
- data:
- type: array
- items:
- $ref: '#/components/schemas/VectorDB'
- description: List of vector databases
- additionalProperties: false
- required:
- - data
- title: ListVectorDBsResponse
- description: Response from listing vector databases.
- RegisterVectorDbRequest:
- type: object
- properties:
- vector_db_id:
- type: string
- description: >-
- The identifier of the vector database to register.
- embedding_model:
- type: string
- description: The embedding model to use.
- embedding_dimension:
- type: integer
- description: The dimension of the embedding model.
- provider_id:
- type: string
- description: The identifier of the provider.
- vector_db_name:
- type: string
- description: The name of the vector database.
- provider_vector_db_id:
- type: string
- description: >-
- The identifier of the vector database in the provider.
- additionalProperties: false
- required:
- - vector_db_id
- - embedding_model
- title: RegisterVectorDbRequest
Chunk:
type: object
properties:
@@ -14212,8 +14013,6 @@ tags:
description: ''
- name: ToolRuntime
description: ''
- - name: VectorDBs
- description: ''
- name: VectorIO
description: ''
x-tagGroups:
@@ -14240,5 +14039,4 @@ x-tagGroups:
- Telemetry
- ToolGroups
- ToolRuntime
- - VectorDBs
- VectorIO
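Across all four spec files the change is identical: the `/v1/vector-dbs` paths, the `VectorDB`, `ListVectorDBsResponse`, and `RegisterVectorDbRequest` schemas, and the `VectorDBs` tag are removed. Clients that registered vector databases should move to the OpenAI-compatible vector stores surface that the stack already exposes (the deleted routing table further down forwarded to it internally). A minimal migration sketch, assuming a `llama_stack_client` with an OpenAI-style `vector_stores` resource; the `extra_body` keys are assumptions carried over from the removed `RegisterVectorDbRequest` fields, not something this diff specifies:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Removed by this diff:
#   client.vector_dbs.register(
#       vector_db_id="my-db",
#       embedding_model="all-MiniLM-L6-v2",
#       embedding_dimension=384,
#   )

# Assumed replacement: create a vector store; the server generates the id.
vector_store = client.vector_stores.create(
    name="my-db",
    extra_body={
        "embedding_model": "all-MiniLM-L6-v2",  # was RegisterVectorDbRequest.embedding_model
        "embedding_dimension": 384,             # was RegisterVectorDbRequest.embedding_dimension
    },
)
print(vector_store.id)
```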
diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py
index 521d129c6..53bf181e9 100644
--- a/llama_stack/apis/vector_dbs/vector_dbs.py
+++ b/llama_stack/apis/vector_dbs/vector_dbs.py
@@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Literal, Protocol, runtime_checkable
+from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, webmethod
+from llama_stack.schema_utils import json_schema_type
@json_schema_type
@@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
"""
data: list[VectorDB]
-
-
-@runtime_checkable
-@trace_protocol
-class VectorDBs(Protocol):
- @webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
- async def list_vector_dbs(self) -> ListVectorDBsResponse:
- """List all vector databases.
-
- :returns: A ListVectorDBsResponse.
- """
- ...
-
- @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
- async def get_vector_db(
- self,
- vector_db_id: str,
- ) -> VectorDB:
- """Get a vector database by its identifier.
-
- :param vector_db_id: The identifier of the vector database to get.
- :returns: A VectorDB.
- """
- ...
-
- @webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
- async def register_vector_db(
- self,
- vector_db_id: str,
- embedding_model: str,
- embedding_dimension: int | None = 384,
- provider_id: str | None = None,
- vector_db_name: str | None = None,
- provider_vector_db_id: str | None = None,
- ) -> VectorDB:
- """Register a vector database.
-
- :param vector_db_id: The identifier of the vector database to register.
- :param embedding_model: The embedding model to use.
- :param embedding_dimension: The dimension of the embedding model.
- :param provider_id: The identifier of the provider.
- :param vector_db_name: The name of the vector database.
- :param provider_vector_db_id: The identifier of the vector database in the provider.
- :returns: A VectorDB.
- """
- ...
-
- @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
- async def unregister_vector_db(self, vector_db_id: str) -> None:
- """Unregister a vector database.
-
- :param vector_db_id: The identifier of the vector database to unregister.
- """
- ...
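After this hunk, `vector_dbs.py` keeps only the pydantic resource models; what disappears is the `Protocol` with its `@webmethod` routes. A sketch of the surviving module, reconstructed from the retained imports and the `VectorDB` schema removed from the spec above (field defaults and the exact `Literal` pattern are assumptions):

```python
# Sketch of what vector_dbs.py retains: resource models only, no HTTP protocol.
from typing import Literal

from pydantic import BaseModel

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class VectorDB(Resource):
    """Vector database resource for storing and querying vector embeddings."""

    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
    embedding_model: str
    embedding_dimension: int
    vector_db_name: str | None = None


class ListVectorDBsResponse(BaseModel):
    """Response from listing vector databases."""

    data: list[VectorDB]
```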
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 749253865..c8011307d 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
-from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
@@ -81,7 +80,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
- Api.vector_dbs: VectorDBs,
Api.models: Models,
Api.safety: Safety,
Api.shields: Shields,
@@ -125,7 +123,7 @@ def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
- Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
+ Api.vector_io: (VectorDBsProtocolPrivate,),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py
index f129f8ede..a1a8b0144 100644
--- a/llama_stack/core/routers/__init__.py
+++ b/llama_stack/core/routers/__init__.py
@@ -26,10 +26,8 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
- from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = {
- "vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py
index 0800b909b..0b5aa7843 100644
--- a/llama_stack/core/routing_tables/common.py
+++ b/llama_stack/core/routing_tables/common.py
@@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
- from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
- elif isinstance(self, VectorDBsRoutingTable):
- return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py
deleted file mode 100644
index 995e0351d..000000000
--- a/llama_stack/core/routing_tables/vector_dbs.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any
-
-from pydantic import TypeAdapter
-
-from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
-from llama_stack.apis.models import ModelType
-from llama_stack.apis.resource import ResourceType
-from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
-from llama_stack.apis.vector_io.vector_io import (
- SearchRankingOptions,
- VectorStoreChunkingStrategy,
- VectorStoreDeleteResponse,
- VectorStoreFileContentsResponse,
- VectorStoreFileDeleteResponse,
- VectorStoreFileObject,
- VectorStoreFileStatus,
- VectorStoreObject,
- VectorStoreSearchResponsePage,
-)
-from llama_stack.core.datatypes import (
- VectorDBWithOwner,
-)
-from llama_stack.log import get_logger
-
-from .common import CommonRoutingTableImpl, lookup_model
-
-logger = get_logger(name=__name__, category="core::routing_tables")
-
-
-class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
- async def list_vector_dbs(self) -> ListVectorDBsResponse:
- return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
-
- async def get_vector_db(self, vector_db_id: str) -> VectorDB:
- vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
- if vector_db is None:
- raise VectorStoreNotFoundError(vector_db_id)
- return vector_db
-
- async def register_vector_db(
- self,
- vector_db_id: str,
- embedding_model: str,
- embedding_dimension: int | None = 384,
- provider_id: str | None = None,
- provider_vector_db_id: str | None = None,
- vector_db_name: str | None = None,
- ) -> VectorDB:
- if provider_id is None:
- if len(self.impls_by_provider_id) > 0:
- provider_id = list(self.impls_by_provider_id.keys())[0]
- if len(self.impls_by_provider_id) > 1:
- logger.warning(
- f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
- )
- else:
- raise ValueError("No provider available. Please configure a vector_io provider.")
- model = await lookup_model(self, embedding_model)
- if model is None:
- raise ModelNotFoundError(embedding_model)
- if model.model_type != ModelType.embedding:
- raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
- if "embedding_dimension" not in model.metadata:
- raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
-
- provider = self.impls_by_provider_id[provider_id]
- vector_store = await provider.openai_create_vector_store(
- name=vector_db_name or vector_db_id,
- embedding_model=embedding_model,
- embedding_dimension=model.metadata["embedding_dimension"],
- provider_id=provider_id,
- provider_vector_db_id=provider_vector_db_id,
- )
-
- vector_store_id = vector_store.id
- actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
- logger.warning(
- f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
- )
-
- vector_db_data = {
- "identifier": vector_store_id,
- "type": ResourceType.vector_db.value,
- "provider_id": provider_id,
- "provider_resource_id": actual_provider_vector_db_id,
- "embedding_model": embedding_model,
- "embedding_dimension": model.metadata["embedding_dimension"],
- "vector_db_name": vector_store.name,
- }
- vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
- await self.register_object(vector_db)
- return vector_db
-
- async def unregister_vector_db(self, vector_db_id: str) -> None:
- existing_vector_db = await self.get_vector_db(vector_db_id)
- await self.unregister_object(existing_vector_db)
-
- async def openai_retrieve_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreObject:
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_retrieve_vector_store(vector_store_id)
-
- async def openai_update_vector_store(
- self,
- vector_store_id: str,
- name: str | None = None,
- expires_after: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> VectorStoreObject:
- await self.assert_action_allowed("update", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_update_vector_store(
- vector_store_id=vector_store_id,
- name=name,
- expires_after=expires_after,
- metadata=metadata,
- )
-
- async def openai_delete_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreDeleteResponse:
- await self.assert_action_allowed("delete", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- result = await provider.openai_delete_vector_store(vector_store_id)
- await self.unregister_vector_db(vector_store_id)
- return result
-
- async def openai_search_vector_store(
- self,
- vector_store_id: str,
- query: str | list[str],
- filters: dict[str, Any] | None = None,
- max_num_results: int | None = 10,
- ranking_options: SearchRankingOptions | None = None,
- rewrite_query: bool | None = False,
- search_mode: str | None = "vector",
- ) -> VectorStoreSearchResponsePage:
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_search_vector_store(
- vector_store_id=vector_store_id,
- query=query,
- filters=filters,
- max_num_results=max_num_results,
- ranking_options=ranking_options,
- rewrite_query=rewrite_query,
- search_mode=search_mode,
- )
-
- async def openai_attach_file_to_vector_store(
- self,
- vector_store_id: str,
- file_id: str,
- attributes: dict[str, Any] | None = None,
- chunking_strategy: VectorStoreChunkingStrategy | None = None,
- ) -> VectorStoreFileObject:
- await self.assert_action_allowed("update", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_attach_file_to_vector_store(
- vector_store_id=vector_store_id,
- file_id=file_id,
- attributes=attributes,
- chunking_strategy=chunking_strategy,
- )
-
- async def openai_list_files_in_vector_store(
- self,
- vector_store_id: str,
- limit: int | None = 20,
- order: str | None = "desc",
- after: str | None = None,
- before: str | None = None,
- filter: VectorStoreFileStatus | None = None,
- ) -> list[VectorStoreFileObject]:
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_list_files_in_vector_store(
- vector_store_id=vector_store_id,
- limit=limit,
- order=order,
- after=after,
- before=before,
- filter=filter,
- )
-
- async def openai_retrieve_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileObject:
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_retrieve_vector_store_file(
- vector_store_id=vector_store_id,
- file_id=file_id,
- )
-
- async def openai_retrieve_vector_store_file_contents(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileContentsResponse:
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_retrieve_vector_store_file_contents(
- vector_store_id=vector_store_id,
- file_id=file_id,
- )
-
- async def openai_update_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- attributes: dict[str, Any],
- ) -> VectorStoreFileObject:
- await self.assert_action_allowed("update", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_update_vector_store_file(
- vector_store_id=vector_store_id,
- file_id=file_id,
- attributes=attributes,
- )
-
- async def openai_delete_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileDeleteResponse:
- await self.assert_action_allowed("delete", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_delete_vector_store_file(
- vector_store_id=vector_store_id,
- file_id=file_id,
- )
-
- async def openai_create_vector_store_file_batch(
- self,
- vector_store_id: str,
- file_ids: list[str],
- attributes: dict[str, Any] | None = None,
- chunking_strategy: Any | None = None,
- ):
- await self.assert_action_allowed("update", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_create_vector_store_file_batch(
- vector_store_id=vector_store_id,
- file_ids=file_ids,
- attributes=attributes,
- chunking_strategy=chunking_strategy,
- )
-
- async def openai_retrieve_vector_store_file_batch(
- self,
- batch_id: str,
- vector_store_id: str,
- ):
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_retrieve_vector_store_file_batch(
- batch_id=batch_id,
- vector_store_id=vector_store_id,
- )
-
- async def openai_list_files_in_vector_store_file_batch(
- self,
- batch_id: str,
- vector_store_id: str,
- after: str | None = None,
- before: str | None = None,
- filter: str | None = None,
- limit: int | None = 20,
- order: str | None = "desc",
- ):
- await self.assert_action_allowed("read", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_list_files_in_vector_store_file_batch(
- batch_id=batch_id,
- vector_store_id=vector_store_id,
- after=after,
- before=before,
- filter=filter,
- limit=limit,
- order=order,
- )
-
- async def openai_cancel_vector_store_file_batch(
- self,
- batch_id: str,
- vector_store_id: str,
- ):
- await self.assert_action_allowed("update", "vector_db", vector_store_id)
- provider = await self.get_provider_impl(vector_store_id)
- return await provider.openai_cancel_vector_store_file_batch(
- batch_id=batch_id,
- vector_store_id=vector_store_id,
- )
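The deleted routing table makes the deprecation story concrete: `register_vector_db` was already a shim over `openai_create_vector_store`, it ignored the caller-supplied `vector_db_id` in favor of the generated vector store id (and warned about it), and every `openai_*` method simply checked ACLs and forwarded to the provider. A hedged client-side equivalent of what the shim provided; the `vector_stores` method names follow the OpenAI-compatible surface the shim forwarded to and are not confirmed by this diff:

```python
# What the removed shim amounted to, expressed as direct client calls.
vector_store = client.vector_stores.create(name="rag_vector_db")

stores = client.vector_stores.list()                 # was list_vector_dbs
store = client.vector_stores.retrieve(vector_store.id)  # was get_vector_db
client.vector_stores.delete(vector_store.id)         # was unregister_vector_db
```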
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 2eab9344f..f161ac358 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
-from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
@@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
class LlamaStack(
Providers,
- VectorDBs,
Inference,
Agents,
Safety,
@@ -83,7 +81,6 @@ class LlamaStack(
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
- ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
diff --git a/llama_stack/core/ui/page/distribution/resources.py b/llama_stack/core/ui/page/distribution/resources.py
index c56fcfff3..6e7122ceb 100644
--- a/llama_stack/core/ui/page/distribution/resources.py
+++ b/llama_stack/core/ui/page/distribution/resources.py
@@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
from llama_stack.core.ui.page.distribution.models import models
from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
from llama_stack.core.ui.page.distribution.shields import shields
-from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
def resources_page():
options = [
"Models",
- "Vector Databases",
"Shields",
"Scoring Functions",
"Datasets",
"Benchmarks",
]
- icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
+ icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
None,
options,
@@ -37,8 +35,6 @@ def resources_page():
)
if selected_resource == "Benchmarks":
benchmarks()
- elif selected_resource == "Vector Databases":
- vector_dbs()
elif selected_resource == "Datasets":
datasets()
elif selected_resource == "Models":
diff --git a/llama_stack/core/ui/page/distribution/vector_dbs.py b/llama_stack/core/ui/page/distribution/vector_dbs.py
deleted file mode 100644
index e81077d2a..000000000
--- a/llama_stack/core/ui/page/distribution/vector_dbs.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import streamlit as st
-
-from llama_stack.core.ui.modules.api import llama_stack_api
-
-
-def vector_dbs():
- st.header("Vector Databases")
- vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
-
- if len(vector_dbs_info) > 0:
- selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
- st.json(vector_dbs_info[selected_vector_db])
- else:
- st.info("No vector databases found")
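The removed Streamlit page fetched its data via `client.vector_dbs.list()`. If an equivalent resource view is ever restored, the natural data source is the surviving vector stores listing; a one-line sketch under the assumption that the client exposes `vector_stores.list()` and that its items carry `id` and `to_dict()` like the old objects did:

```python
# Hypothetical replacement data source for the deleted page.
vector_stores_info = {vs.id: vs.to_dict() for vs in llama_stack_api.client.vector_stores.list()}
```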
diff --git a/llama_stack/core/ui/page/playground/rag.py b/llama_stack/core/ui/page/playground/rag.py
deleted file mode 100644
index 2ffae1c33..000000000
--- a/llama_stack/core/ui/page/playground/rag.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import uuid
-
-import streamlit as st
-from llama_stack_client import Agent, AgentEventLogger, RAGDocument
-
-from llama_stack.apis.common.content_types import ToolCallDelta
-from llama_stack.core.ui.modules.api import llama_stack_api
-from llama_stack.core.ui.modules.utils import data_url_from_file
-
-
-def rag_chat_page():
- st.title("🦙 RAG")
-
- def reset_agent_and_chat():
- st.session_state.clear()
- st.cache_resource.clear()
-
- def should_disable_input():
- return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
-
- def log_message(message):
- with st.chat_message(message["role"]):
- if "tool_output" in message and message["tool_output"]:
- with st.expander(label="Tool Output", expanded=False, icon="🛠"):
- st.write(message["tool_output"])
- st.markdown(message["content"])
-
- with st.sidebar:
- # File/Directory Upload Section
- st.subheader("Upload Documents", divider=True)
- uploaded_files = st.file_uploader(
- "Upload file(s) or directory",
- accept_multiple_files=True,
- type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
- )
- # Process uploaded files
- if uploaded_files:
- st.success(f"Successfully uploaded {len(uploaded_files)} files")
- # Add memory bank name input field
- vector_db_name = st.text_input(
- "Document Collection Name",
- value="rag_vector_db",
- help="Enter a unique identifier for this document collection",
- )
- if st.button("Create Document Collection"):
- documents = [
- RAGDocument(
- document_id=uploaded_file.name,
- content=data_url_from_file(uploaded_file),
- )
- for i, uploaded_file in enumerate(uploaded_files)
- ]
-
- providers = llama_stack_api.client.providers.list()
- vector_io_provider = None
-
- for x in providers:
- if x.api == "vector_io":
- vector_io_provider = x.provider_id
-
- llama_stack_api.client.vector_dbs.register(
- vector_db_id=vector_db_name, # Use the user-provided name
- embedding_dimension=384,
- embedding_model="all-MiniLM-L6-v2",
- provider_id=vector_io_provider,
- )
-
- # insert documents using the custom vector db name
- llama_stack_api.client.tool_runtime.rag_tool.insert(
- vector_db_id=vector_db_name, # Use the user-provided name
- documents=documents,
- chunk_size_in_tokens=512,
- )
- st.success("Vector database created successfully!")
-
- st.subheader("RAG Parameters", divider=True)
-
- rag_mode = st.radio(
- "RAG mode",
- ["Direct", "Agent-based"],
- captions=[
- "RAG is performed by directly retrieving the information and augmenting the user query",
- "RAG is performed by an agent activating a dedicated knowledge search tool.",
- ],
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
-
- # select memory banks
- vector_dbs = llama_stack_api.client.vector_dbs.list()
- vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
- selected_vector_dbs = st.multiselect(
- label="Select Document Collections to use in RAG queries",
- options=vector_dbs,
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
-
- st.subheader("Inference Parameters", divider=True)
- available_models = llama_stack_api.client.models.list()
- available_models = [model.identifier for model in available_models if model.model_type == "llm"]
- selected_model = st.selectbox(
- label="Choose a model",
- options=available_models,
- index=0,
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
- system_prompt = st.text_area(
- "System Prompt",
- value="You are a helpful assistant. ",
- help="Initial instructions given to the AI to set its behavior and context",
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
- temperature = st.slider(
- "Temperature",
- min_value=0.0,
- max_value=1.0,
- value=0.0,
- step=0.1,
- help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
-
- top_p = st.slider(
- "Top P",
- min_value=0.0,
- max_value=1.0,
- value=0.95,
- step=0.1,
- on_change=reset_agent_and_chat,
- disabled=should_disable_input(),
- )
-
- # Add clear chat button to sidebar
- if st.button("Clear Chat", use_container_width=True):
- reset_agent_and_chat()
- st.rerun()
-
- # Chat Interface
- if "messages" not in st.session_state:
- st.session_state.messages = []
- if "displayed_messages" not in st.session_state:
- st.session_state.displayed_messages = []
-
- # Display chat history
- for message in st.session_state.displayed_messages:
- log_message(message)
-
- if temperature > 0.0:
- strategy = {
- "type": "top_p",
- "temperature": temperature,
- "top_p": top_p,
- }
- else:
- strategy = {"type": "greedy"}
-
- @st.cache_resource
- def create_agent():
- return Agent(
- llama_stack_api.client,
- model=selected_model,
- instructions=system_prompt,
- sampling_params={
- "strategy": strategy,
- },
- tools=[
- dict(
- name="builtin::rag/knowledge_search",
- args={
- "vector_db_ids": list(selected_vector_dbs),
- },
- )
- ],
- )
-
- if rag_mode == "Agent-based":
- agent = create_agent()
- if "agent_session_id" not in st.session_state:
- st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
-
- session_id = st.session_state["agent_session_id"]
-
- def agent_process_prompt(prompt):
- # Add user message to chat history
- st.session_state.messages.append({"role": "user", "content": prompt})
-
- # Send the prompt to the agent
- response = agent.create_turn(
- messages=[
- {
- "role": "user",
- "content": prompt,
- }
- ],
- session_id=session_id,
- )
-
- # Display assistant response
- with st.chat_message("assistant"):
- retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
- message_placeholder = st.empty()
- full_response = ""
- retrieval_response = ""
- for log in AgentEventLogger().log(response):
- log.print()
- if log.role == "tool_execution":
- retrieval_response += log.content.replace("====", "").strip()
- retrieval_message_placeholder.write(retrieval_response)
- else:
- full_response += log.content
- message_placeholder.markdown(full_response + "▌")
- message_placeholder.markdown(full_response)
-
- st.session_state.messages.append({"role": "assistant", "content": full_response})
- st.session_state.displayed_messages.append(
- {"role": "assistant", "content": full_response, "tool_output": retrieval_response}
- )
-
- def direct_process_prompt(prompt):
- # Add the system prompt in the beginning of the conversation
- if len(st.session_state.messages) == 0:
- st.session_state.messages.append({"role": "system", "content": system_prompt})
-
- # Query the vector DB
- rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
- content=prompt, vector_db_ids=list(selected_vector_dbs)
- )
- prompt_context = rag_response.content
-
- with st.chat_message("assistant"):
- with st.expander(label="Retrieval Output", expanded=False):
- st.write(prompt_context)
-
- retrieval_message_placeholder = st.empty()
- message_placeholder = st.empty()
- full_response = ""
- retrieval_response = ""
-
- # Construct the extended prompt
- extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
-
- # Run inference directly
- st.session_state.messages.append({"role": "user", "content": extended_prompt})
- response = llama_stack_api.client.inference.chat_completion(
- messages=st.session_state.messages,
- model_id=selected_model,
- sampling_params={
- "strategy": strategy,
- },
- stream=True,
- )
-
- # Display assistant response
- for chunk in response:
- response_delta = chunk.event.delta
- if isinstance(response_delta, ToolCallDelta):
- retrieval_response += response_delta.tool_call.replace("====", "").strip()
- retrieval_message_placeholder.info(retrieval_response)
- else:
- full_response += chunk.event.delta.text
- message_placeholder.markdown(full_response + "▌")
- message_placeholder.markdown(full_response)
-
- response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
- st.session_state.messages.append(response_dict)
- st.session_state.displayed_messages.append(response_dict)
-
- # Chat input
- if prompt := st.chat_input("Ask a question about your documents"):
- # Add user message to chat history
- st.session_state.displayed_messages.append({"role": "user", "content": prompt})
-
- # Display user message
- with st.chat_message("user"):
- st.markdown(prompt)
-
- # store the prompt to process it after page refresh
- st.session_state.prompt = prompt
-
- # force page refresh to disable the settings widgets
- st.rerun()
-
- if "prompt" in st.session_state and st.session_state.prompt is not None:
- if rag_mode == "Agent-based":
- agent_process_prompt(st.session_state.prompt)
- else: # rag_mode == "Direct"
- direct_process_prompt(st.session_state.prompt)
- st.session_state.prompt = None
-
-
-rag_chat_page()
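The deleted playground page chained `vector_dbs.register` with `tool_runtime.rag_tool.insert` for ingestion and `rag_tool.query` for direct retrieval. A compressed sketch of the same flow on the vector stores surface, hedged because none of these replacement calls appear in this diff (the files-API upload and `vector_stores.search` names are assumptions):

```python
# Ingestion: upload a document, then attach it to a vector store.
with open("memory_optimizations.rst", "rb") as f:
    uploaded = client.files.create(file=f, purpose="assistants")

vector_store = client.vector_stores.create(name="rag_vector_db")
client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded.id)

# Direct retrieval: search the store instead of rag_tool.query.
results = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="How do I reduce memory usage during finetuning?",
)
```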
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 07ba7bb01..3542facef 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -8,7 +8,6 @@ from typing import Any
from uuid import uuid4
import pytest
-import requests
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
@@ -443,118 +442,6 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
return [step for step in response.steps if step.step_type == "tool_execution"]
-@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
-def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
- urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
- documents = [
- Document(
- document_id=f"num-{i}",
- content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
- mime_type="text/plain",
- metadata={},
- )
- for i, url in enumerate(urls)
- ]
- vector_db_id = f"test-vector-db-{uuid4()}"
- llama_stack_client.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model="all-MiniLM-L6-v2",
- embedding_dimension=384,
- )
- llama_stack_client.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=vector_db_id,
- # small chunks help to get specific info out of the docs
- chunk_size_in_tokens=256,
- )
- agent_config = {
- **agent_config,
- "tools": [
- dict(
- name=rag_tool_name,
- args={
- "vector_db_ids": [vector_db_id],
- },
- )
- ],
- }
- rag_agent = Agent(llama_stack_client, **agent_config)
- session_id = rag_agent.create_session(f"test-session-{uuid4()}")
- user_prompts = [
- (
- "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
- "grouped",
- ),
- ]
- for prompt, expected_kw in user_prompts:
- response = rag_agent.create_turn(
- messages=[{"role": "user", "content": prompt}],
- session_id=session_id,
- stream=False,
- )
- # rag is called
- tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
- assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
- # document ids are present in metadata
- assert all(
- doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
- )
- if expected_kw:
- assert expected_kw in response.output_message.content.lower()
-
-
-def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
- urls = ["llama3.rst", "lora_finetune.rst"]
- documents = [
- # passign as url
- Document(
- document_id="num-0",
- content={
- "type": "url",
- "uri": f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[0]}",
- },
- mime_type="text/plain",
- metadata={},
- ),
- # passing as str
- Document(
- document_id="num-1",
- content=requests.get(
- f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[1]}"
- ).text[:500],
- mime_type="text/plain",
- metadata={},
- ),
- ]
- rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
- session_id = rag_agent.create_session(f"test-session-{uuid4()}")
- user_prompts = [
- (
- "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
- documents,
- ),
- (
- "Tell me how to use LoRA in 100 words or less",
- None,
- ),
- ]
-
- for prompt in user_prompts:
- response = rag_agent.create_turn(
- messages=[
- {
- "role": "user",
- "content": prompt[0],
- }
- ],
- documents=prompt[1],
- session_id=session_id,
- stream=False,
- )
-
- assert "lora" in response.output_message.content.lower()
-
-
@pytest.mark.parametrize(
"client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
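Both deleted agent tests hinged on `vector_dbs.register`, which no longer exists, so they are dropped rather than ported. If the first one were re-targeted, the setup would plausibly become something like the sketch below; whether `builtin::rag/knowledge_search` still accepts a vector store id under the `vector_db_ids` arg is an assumption, not something this diff establishes:

```python
# Hypothetical re-targeting of the deleted RAG agent test setup.
vector_store = llama_stack_client.vector_stores.create(name=f"test-{uuid4()}")
agent_config = {
    **agent_config,
    "tools": [
        {
            "name": "builtin::rag/knowledge_search",
            # Assumption: the tool accepts vector store ids here.
            "args": {"vector_db_ids": [vector_store.id]},
        }
    ],
}
rag_agent = Agent(llama_stack_client, **agent_config)
```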
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
deleted file mode 100644
index b78c39af8..000000000
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ /dev/null
@@ -1,459 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-from llama_stack_client import BadRequestError
-from llama_stack_client.types import Document
-
-
-@pytest.fixture(scope="function")
-def client_with_empty_registry(client_with_models):
- def clear_registry():
- vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
- for vector_db_id in vector_dbs:
- client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
-
- clear_registry()
-
- try:
- client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
- except Exception:
- pass
-
- yield client_with_models
-
- clear_registry()
-
-
-@pytest.fixture(scope="session")
-def sample_documents():
- return [
- Document(
- document_id="test-doc-1",
- content="Python is a high-level programming language.",
- metadata={"category": "programming", "difficulty": "beginner"},
- ),
- Document(
- document_id="test-doc-2",
- content="Machine learning is a subset of artificial intelligence.",
- metadata={"category": "AI", "difficulty": "advanced"},
- ),
- Document(
- document_id="test-doc-3",
- content="Data structures are fundamental to computer science.",
- metadata={"category": "computer science", "difficulty": "intermediate"},
- ),
- Document(
- document_id="test-doc-4",
- content="Neural networks are inspired by biological neural networks.",
- metadata={"category": "AI", "difficulty": "advanced"},
- ),
- ]
-
-
-def assert_valid_chunk_response(response):
- assert len(response.chunks) > 0
- assert len(response.scores) > 0
- assert len(response.chunks) == len(response.scores)
- for chunk in response.chunks:
- assert isinstance(chunk.content, str)
-
-
-def assert_valid_text_response(response):
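- """Response content must be non-empty and consist of text chunks."""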
- assert len(response.content) > 0
- assert all(isinstance(chunk.text, str) for chunk in response.content)
-
-
-def test_vector_db_insert_inline_and_query(
- client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
-):
- vector_db_name = "test_vector_db"
- vector_db = client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_name,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
- vector_db_id = vector_db.identifier
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=sample_documents,
- chunk_size_in_tokens=512,
- vector_db_id=vector_db_id,
- )
-
- # Query with a direct match
- query1 = "programming language"
- response1 = client_with_empty_registry.vector_io.query(
- vector_db_id=vector_db_id,
- query=query1,
- )
- assert_valid_chunk_response(response1)
- assert any("Python" in chunk.content for chunk in response1.chunks)
-
- # Query with semantic similarity
- query2 = "AI and brain-inspired computing"
- response2 = client_with_empty_registry.vector_io.query(
- vector_db_id=vector_db_id,
- query=query2,
- )
- assert_valid_chunk_response(response2)
- assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
-
- # Query with limit on number of results (max_chunks=2)
- query3 = "computer"
- response3 = client_with_empty_registry.vector_io.query(
- vector_db_id=vector_db_id,
- query=query3,
- params={"max_chunks": 2},
- )
- assert_valid_chunk_response(response3)
- assert len(response3.chunks) <= 2
-
- # Query with threshold on similarity score
- query4 = "computer"
- response4 = client_with_empty_registry.vector_io.query(
- vector_db_id=vector_db_id,
- query=query4,
- params={"score_threshold": 0.01},
- )
- assert_valid_chunk_response(response4)
- assert all(score >= 0.01 for score in response4.scores)
-
-
-def test_vector_db_insert_from_url_and_query(
- client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
-):
- providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
- assert len(providers) > 0
-
- vector_db_id = "test_vector_db"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- # list to verify the vector DB was registered successfully
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- # VectorDB is being migrated to VectorStore, so the registered ID may differ from the one requested
- # Just check that at least one vector DB was registered
- assert len(available_vector_dbs) > 0
- # Use the actual registered vector_db_id for subsequent operations
- actual_vector_db_id = available_vector_dbs[0]
-
- urls = [
- "memory_optimizations.rst",
- "chat.rst",
- "llama3.rst",
- ]
- documents = [
- Document(
- document_id=f"num-{i}",
- content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
- mime_type="text/plain",
- metadata={},
- )
- for i, url in enumerate(urls)
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=512,
- )
-
- # Query for the name of the fine-tuning method
- response1 = client_with_empty_registry.vector_io.query(
- vector_db_id=actual_vector_db_id,
- query="What's the name of the fine-tunning method used?",
- )
- assert_valid_chunk_response(response1)
- assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
-
- # Query for the name of the model
- response2 = client_with_empty_registry.vector_io.query(
- vector_db_id=actual_vector_db_id,
- query="Which Llama model is mentioned?",
- )
- assert_valid_chunk_response(response2)
- assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
-
-
-def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
- vector_db_id = "test_openai_vector_db"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- actual_vector_db_id = available_vector_dbs[0]
-
- # different document formats that should work with OpenAI APIs
- documents = [
- Document(
- document_id="text-doc",
- content="This is a plain text document about machine learning algorithms.",
- metadata={"type": "text", "category": "AI"},
- ),
- Document(
- document_id="url-doc",
- content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
- mime_type="text/plain",
- metadata={"type": "url", "source": "pytorch"},
- ),
- Document(
- document_id="data-url-doc",
- content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
- metadata={"type": "data_url", "encoding": "base64"},
- ),
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=256,
- )
-
- files_list = client_with_empty_registry.files.list()
- assert len(files_list.data) >= len(documents), (
- f"Expected at least {len(documents)} files, got {len(files_list.data)}"
- )
-
- vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
- vector_store_id=actual_vector_db_id
- )
- assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
-
- response = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="Tell me about machine learning and deep learning",
- )
-
- assert_valid_text_response(response)
- content_text = " ".join([chunk.text for chunk in response.content]).lower()
- assert "machine learning" in content_text or "deep learning" in content_text
-
-
-def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
- vector_db_id = "test_exception_handling"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- actual_vector_db_id = available_vector_dbs[0]
-
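- # one unreachable URL among valid docs: insertion should tolerate the failure and keep the rest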
- documents = [
- Document(
- document_id="valid-doc",
- content="This is a valid document that should be processed successfully.",
- metadata={"status": "valid"},
- ),
- Document(
- document_id="invalid-url-doc",
- content="https://nonexistent-domain-12345.com/invalid.txt",
- metadata={"status": "invalid_url"},
- ),
- Document(
- document_id="another-valid-doc",
- content="This is another valid document for testing resilience.",
- metadata={"status": "valid"},
- ),
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=256,
- )
-
- response = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="valid document",
- )
-
- assert_valid_text_response(response)
- content_text = " ".join([chunk.text for chunk in response.content]).lower()
- assert "valid document" in content_text
-
-
-def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
- providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
- assert len(providers) > 0
-
- vector_db_id = "test_vector_db"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- # VectorDB is being migrated to VectorStore, so the registered ID may differ from the one requested
- # Just check that at least one vector DB was registered
- assert len(available_vector_dbs) > 0
- # Use the actual registered vector_db_id for subsequent operations
- actual_vector_db_id = available_vector_dbs[0]
-
- urls = [
- "memory_optimizations.rst",
- "chat.rst",
- "llama3.rst",
- ]
- documents = [
- Document(
- document_id=f"num-{i}",
- content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
- mime_type="text/plain",
- metadata={"author": "llama", "source": url},
- )
- for i, url in enumerate(urls)
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=512,
- )
-
- response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="What is the name of the method used for fine-tuning?",
- )
- assert_valid_text_response(response_with_metadata)
- assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
-
- response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="What is the name of the method used for fine-tuning?",
- query_config={
- "include_metadata_in_content": True,
- "chunk_template": "Result {index}\nContent: {chunk.content}\n",
- },
- )
- assert_valid_text_response(response_without_metadata)
- assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
-
- with pytest.raises((ValueError, BadRequestError)):
- client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="What is the name of the method used for fine-tuning?",
- query_config={
- "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
- },
- )
-
-
-def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
- vector_db_id = "test_query_generation_db"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- actual_vector_db_id = available_vector_dbs[0]
-
- documents = [
- Document(
- document_id="ai-doc",
- content="Artificial intelligence and machine learning are transforming technology.",
- metadata={"category": "AI"},
- ),
- Document(
- document_id="banana-doc",
- content="Don't bring a banana to a knife fight.",
- metadata={"category": "wisdom"},
- ),
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=256,
- )
-
- response = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="Tell me about AI",
- )
-
- assert_valid_text_response(response)
- content_text = " ".join([chunk.text for chunk in response.content]).lower()
- assert "artificial intelligence" in content_text or "machine learning" in content_text
-
-
-def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
- vector_db_id = "test_pdf_data_url_db"
-
- client_with_empty_registry.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model=embedding_model_id,
- embedding_dimension=embedding_dimension,
- )
-
- available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
- actual_vector_db_id = available_vector_dbs[0]
-
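- # minimal single-page PDF (PyFPDF output) whose metadata title is "This is a sample title."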
- sample_pdf = b"%PDF-1.3\n3 0 obj\n<>\nendobj\n4 0 obj\n<>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<>\nendobj\n5 0 obj\n<>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
-
- pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
- pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
-
- documents = [
- Document(
- document_id="test-pdf-data-url",
- content=pdf_data_url,
- metadata={"type": "pdf", "source": "data_url"},
- ),
- ]
-
- client_with_empty_registry.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=actual_vector_db_id,
- chunk_size_in_tokens=256,
- )
-
- files_list = client_with_empty_registry.files.list()
- assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
-
- pdf_file = None
- for file in files_list.data:
- if file.filename and "test-pdf-data-url" in file.filename:
- pdf_file = file
- break
-
- assert pdf_file is not None, "PDF file should be found in Files API"
- assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
-
- file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
- assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
-
- vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
- vector_store_id=actual_vector_db_id
- )
- assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
-
- response = client_with_empty_registry.tool_runtime.rag_tool.query(
- vector_db_ids=[actual_vector_db_id],
- content="sample title",
- )
-
- assert_valid_text_response(response)
- content_text = " ".join([chunk.text for chunk in response.content]).lower()
- assert "sample title" in content_text or "title" in content_text
diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py
deleted file mode 100644
index 3444f64c2..000000000
--- a/tests/unit/distribution/routing_tables/test_vector_dbs.py
+++ /dev/null
@@ -1,381 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# Unit tests for the vector_dbs routing table
-
-import time
-import uuid
-from unittest.mock import AsyncMock
-
-import pytest
-
-from llama_stack.apis.datatypes import Api
-from llama_stack.apis.models import ModelType
-from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
-from llama_stack.apis.vector_io.vector_io import (
- VectorStoreContent,
- VectorStoreDeleteResponse,
- VectorStoreFileContentsResponse,
- VectorStoreFileCounts,
- VectorStoreFileDeleteResponse,
- VectorStoreFileObject,
- VectorStoreObject,
- VectorStoreSearchResponsePage,
-)
-from llama_stack.core.access_control.datatypes import AccessRule, Scope
-from llama_stack.core.datatypes import User
-from llama_stack.core.request_headers import request_provider_data_context
-from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
-from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
-
-
-class VectorDBImpl(Impl):
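- """In-memory stub of a vector_io provider that returns canned OpenAI-compatible responses."""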
- def __init__(self):
- super().__init__(Api.vector_io)
- self.vector_stores = {}
-
- async def register_vector_db(self, vector_db: VectorDB):
- return vector_db
-
- async def unregister_vector_db(self, vector_db_id: str):
- return vector_db_id
-
- async def openai_retrieve_vector_store(self, vector_store_id):
- return VectorStoreObject(
- id=vector_store_id,
- name="Test Store",
- created_at=int(time.time()),
- file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
- )
-
- async def openai_update_vector_store(self, vector_store_id, **kwargs):
- return VectorStoreObject(
- id=vector_store_id,
- name="Updated Store",
- created_at=int(time.time()),
- file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
- )
-
- async def openai_delete_vector_store(self, vector_store_id):
- return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
-
- async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
- return VectorStoreSearchResponsePage(
- object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
- )
-
- async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
- return VectorStoreFileObject(
- id=file_id,
- status="completed",
- chunking_strategy={"type": "auto"},
- created_at=int(time.time()),
- vector_store_id=vector_store_id,
- )
-
- async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
- return [
- VectorStoreFileObject(
- id="1",
- status="completed",
- chunking_strategy={"type": "auto"},
- created_at=int(time.time()),
- vector_store_id=vector_store_id,
- )
- ]
-
- async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
- return VectorStoreFileObject(
- id=file_id,
- status="completed",
- chunking_strategy={"type": "auto"},
- created_at=int(time.time()),
- vector_store_id=vector_store_id,
- )
-
- async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
- return VectorStoreFileContentsResponse(
- file_id=file_id,
- filename="Sample File name",
- attributes={"key": "value"},
- content=[VectorStoreContent(type="text", text="Sample content")],
- )
-
- async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
- return VectorStoreFileObject(
- id=file_id,
- status="completed",
- chunking_strategy={"type": "auto"},
- created_at=int(time.time()),
- vector_store_id=vector_store_id,
- )
-
- async def openai_delete_vector_store_file(self, vector_store_id, file_id):
- return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
-
- async def openai_create_vector_store(
- self,
- name=None,
- embedding_model=None,
- embedding_dimension=None,
- provider_id=None,
- provider_vector_db_id=None,
- **kwargs,
- ):
- vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
- vector_store = VectorStoreObject(
- id=vector_store_id,
- name=name or vector_store_id,
- created_at=int(time.time()),
- file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
- )
- self.vector_stores[vector_store_id] = vector_store
- return vector_store
-
- async def openai_list_vector_stores(self, **kwargs):
- from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
-
- return VectorStoreListResponse(
- data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
- )
-
-
-async def test_vectordbs_routing_table(cached_disk_dist_registry):
- n = 10
- table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
- await table.initialize()
-
- m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
- await m_table.initialize()
- await m_table.register_model(
- model_id="test-model",
- provider_id="test_provider",
- metadata={"embedding_dimension": 128},
- model_type=ModelType.embedding,
- )
-
- # Register multiple vector databases and verify listing
- vdb_dict = {}
- for i in range(n):
- vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
-
- vector_dbs = await table.list_vector_dbs()
-
- assert len(vector_dbs.data) == len(vdb_dict)
- vector_db_ids = {v.identifier for v in vector_dbs.data}
- for k in vdb_dict:
- assert vdb_dict[k].identifier in vector_db_ids
- for k in vdb_dict:
- await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
-
- vector_dbs = await table.list_vector_dbs()
- assert len(vector_dbs.data) == 0
-
-
-async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
- n = 10
- impl = VectorDBImpl()
- table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
- await table.initialize()
-
- m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
- await m_table.initialize()
- await m_table.register_model(
- model_id="test-model",
- provider_id="test_provider",
- metadata={"embedding_dimension": 128},
- model_type=ModelType.embedding,
- )
-
- vdb_dict = {}
- for i in range(n):
- vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
-
- vector_dbs = await table.list_vector_dbs()
- vector_db_ids = {v.identifier for v in vector_dbs.data}
-
- vector_stores = await impl.openai_list_vector_stores()
- vector_store_ids = {v.id for v in vector_stores.data}
-
- assert vector_db_ids == vector_store_ids, (
- f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
- )
-
- for vector_store in vector_stores.data:
- vector_db = await table.get_vector_db(vector_store.id)
- assert vector_store.name == vector_db.vector_db_name, (
- f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
- )
-
- for vector_db_id in vector_db_ids:
- await table.unregister_vector_db(vector_db_id)
-
- assert len((await table.list_vector_dbs()).data) == 0
-
-
-async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
- impl = VectorDBImpl()
- table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
- await table.initialize()
-
- m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
- await m_table.initialize()
- await m_table.register_model(
- model_id="test-model",
- provider_id="test_provider",
- metadata={"embedding_dimension": 128},
- model_type=ModelType.embedding,
- )
-
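- # the user-provided ID should become the store's display name, while the server assigns a fresh vs_* identifier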
- user_provided_id = "my-custom-vector-db"
- await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
-
- vector_stores = await impl.openai_list_vector_stores()
- assert len(vector_stores.data) == 1
-
- vector_store = vector_stores.data[0]
-
- assert vector_store.name == user_provided_id
-
- assert vector_store.id.startswith("vs_")
- assert vector_store.id != user_provided_id
-
- vector_dbs = await table.list_vector_dbs()
- assert len(vector_dbs.data) == 1
- assert vector_dbs.data[0].identifier == vector_store.id
-
- await table.unregister_vector_db(vector_store.id)
-
-
-async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
- impl = VectorDBImpl()
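- # stub the provider call so only the routing table's access control is exercised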
- impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
- table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
- m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
- authorized_table = "vs1"
- authorized_team = "team1"
- unauthorized_team = "team2"
-
- await m_table.initialize()
- await m_table.register_model(
- model_id="test-model",
- provider_id="test_provider",
- metadata={"embedding_dimension": 128},
- model_type=ModelType.embedding,
- )
-
- authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
- with request_provider_data_context({}, authorized_user):
- registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
- authorized_table = registered_vdb.identifier # Use the actual generated ID
-
- # Authorized reader
- with request_provider_data_context({}, authorized_user):
- res = await table.openai_retrieve_vector_store(authorized_table)
- assert res == "OK"
-
- # Authorized updater
- impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
- with request_provider_data_context({}, authorized_user):
- res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
- assert res == "UPDATED"
-
- # Unauthorized reader
- unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
- with request_provider_data_context({}, unauthorized_user):
- with pytest.raises(ValueError):
- await table.openai_retrieve_vector_store(authorized_table)
-
- # Unauthorized updater
- with request_provider_data_context({}, unauthorized_user):
- with pytest.raises(ValueError):
- await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
-
- # Authorized deleter
- impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
- with request_provider_data_context({}, authorized_user):
- res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
- assert res == "DELETED"
-
- # Unauthorized deleter
- with request_provider_data_context({}, unauthorized_user):
- with pytest.raises(ValueError):
- await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
-
-
-async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
- impl = VectorDBImpl()
-
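- # admin: full CRUD; reader: read-only; anyone else: denied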
- policy = [
- AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
- AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
- ]
-
- table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
- m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
-
- vector_db_id = "vs1"
- file_id = "file-1"
-
- admin_user = User(principal="admin", attributes={"roles": ["admin"]})
- read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
- no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
-
- await m_table.initialize()
- await m_table.register_model(
- model_id="test-model",
- provider_id="test_provider",
- metadata={"embedding_dimension": 128},
- model_type=ModelType.embedding,
- )
-
- with request_provider_data_context({}, admin_user):
- registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
- vector_db_id = registered_vdb.identifier # Use the actual generated ID
-
- read_methods = [
- (table.openai_retrieve_vector_store, (vector_db_id,), {}),
- (table.openai_search_vector_store, (vector_db_id, "query"), {}),
- (table.openai_list_files_in_vector_store, (vector_db_id,), {}),
- (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
- (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
- ]
- update_methods = [
- (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
- (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
- (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
- ]
- delete_methods = [
- (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
- (table.openai_delete_vector_store, (vector_db_id,), {}),
- ]
-
- for user in [admin_user, read_only_user]:
- with request_provider_data_context({}, user):
- for method, args, kwargs in read_methods:
- result = await method(*args, **kwargs)
- assert result is not None, f"Read operation failed with user {user.principal}"
-
- with request_provider_data_context({}, no_access_user):
- for method, args, kwargs in read_methods:
- with pytest.raises(ValueError):
- await method(*args, **kwargs)
-
- with request_provider_data_context({}, admin_user):
- for method, args, kwargs in update_methods:
- result = await method(*args, **kwargs)
- assert result is not None, "Update operation failed with admin user"
-
- with request_provider_data_context({}, admin_user):
- for method, args, kwargs in delete_methods:
- result = await method(*args, **kwargs)
- assert result is not None, "Delete operation failed with admin user"
-
- for user in [read_only_user, no_access_user]:
- with request_provider_data_context({}, user):
- for method, args, kwargs in delete_methods:
- with pytest.raises(ValueError):
- await method(*args, **kwargs)