diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html
index 7e534f995..238a8afa7 100644
--- a/docs/static/llama-stack-spec.html
+++ b/docs/static/llama-stack-spec.html
@@ -2987,165 +2987,6 @@
         "deprecated": false
       }
     },
-    "/v1/vector-dbs": {
-      "get": {
-        "responses": {
-          "200": {
-            "description": "A ListVectorDBsResponse.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/ListVectorDBsResponse"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "List all vector databases.",
-        "description": "List all vector databases.",
-        "parameters": [],
-        "deprecated": false
-      },
-      "post": {
-        "responses": {
-          "200": {
-            "description": "A VectorDB.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/VectorDB"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Register a vector database.",
-        "description": "Register a vector database.",
-        "parameters": [],
-        "requestBody": {
-          "content": {
-            "application/json": {
-              "schema": {
-                "$ref": "#/components/schemas/RegisterVectorDbRequest"
-              }
-            }
-          },
-          "required": true
-        },
-        "deprecated": false
-      }
-    },
-    "/v1/vector-dbs/{vector_db_id}": {
-      "get": {
-        "responses": {
-          "200": {
-            "description": "A VectorDB.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/VectorDB"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Get a vector database by its identifier.",
-        "description": "Get a vector database by its identifier.",
-        "parameters": [
-          {
-            "name": "vector_db_id",
-            "in": "path",
-            "description": "The identifier of the vector database to get.",
-            "required": true,
-            "schema": {
-              "type": "string"
-            }
-          }
-        ],
-        "deprecated": false
-      },
-      "delete": {
-        "responses": {
-          "200": {
-            "description": "OK"
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Unregister a vector database.",
-        "description": "Unregister a vector database.",
-        "parameters": [
-          {
-            "name": "vector_db_id",
-            "in": "path",
-            "description": "The identifier of the vector database to unregister.",
-            "required": true,
-            "schema": {
-              "type": "string"
-            }
-          }
-        ],
-        "deprecated": false
-      }
-    },
     "/v1/vector-io/insert": {
       "post": {
         "responses": {
@@ -11791,111 +11632,6 @@
         ],
         "title": "RegisterToolGroupRequest"
       },
-      "VectorDB": {
-        "type": "object",
-        "properties": {
-          "identifier": {
-            "type": "string"
-          },
-          "provider_resource_id": {
-            "type": "string"
-          },
-          "provider_id": {
-            "type": "string"
-          },
-          "type": {
-            "type": "string",
-            "enum": [
-              "model",
-              "shield",
-              "vector_db",
-              "dataset",
-              "scoring_function",
-              "benchmark",
-              "tool",
-              "tool_group",
-              "prompt"
-            ],
-            "const": "vector_db",
-            "default": "vector_db",
-            "description": "Type of resource, always 'vector_db' for vector databases"
-          },
-          "embedding_model": {
-            "type": "string",
-            "description": "Name of the embedding model to use for vector generation"
-          },
-          "embedding_dimension": {
-            "type": "integer",
-            "description": "Dimension of the embedding vectors"
-          },
-          "vector_db_name": {
-            "type": "string"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "identifier",
-          "provider_id",
-          "type",
-          "embedding_model",
-          "embedding_dimension"
-        ],
-        "title": "VectorDB",
-        "description": "Vector database resource for storing and querying vector embeddings."
-      },
-      "ListVectorDBsResponse": {
-        "type": "object",
-        "properties": {
-          "data": {
-            "type": "array",
-            "items": {
-              "$ref": "#/components/schemas/VectorDB"
-            },
-            "description": "List of vector databases"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "data"
-        ],
-        "title": "ListVectorDBsResponse",
-        "description": "Response from listing vector databases."
-      },
-      "RegisterVectorDbRequest": {
-        "type": "object",
-        "properties": {
-          "vector_db_id": {
-            "type": "string",
-            "description": "The identifier of the vector database to register."
-          },
-          "embedding_model": {
-            "type": "string",
-            "description": "The embedding model to use."
-          },
-          "embedding_dimension": {
-            "type": "integer",
-            "description": "The dimension of the embedding model."
-          },
-          "provider_id": {
-            "type": "string",
-            "description": "The identifier of the provider."
-          },
-          "vector_db_name": {
-            "type": "string",
-            "description": "The name of the vector database."
-          },
-          "provider_vector_db_id": {
-            "type": "string",
-            "description": "The identifier of the vector database in the provider."
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "vector_db_id",
-          "embedding_model"
-        ],
-        "title": "RegisterVectorDbRequest"
-      },
       "Chunk": {
         "type": "object",
         "properties": {
@@ -13371,10 +13107,6 @@
       "name": "ToolRuntime",
       "description": ""
     },
-    {
-      "name": "VectorDBs",
-      "description": ""
-    },
     {
       "name": "VectorIO",
       "description": ""
@@ -13400,7 +13132,6 @@
         "Telemetry",
         "ToolGroups",
         "ToolRuntime",
-        "VectorDBs",
         "VectorIO"
       ]
     }
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index bad40c87d..6957afda8 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -2275,120 +2275,6 @@ paths:
         schema:
           type: string
       deprecated: false
-  /v1/vector-dbs:
-    get:
-      responses:
-        '200':
-          description: A ListVectorDBsResponse.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/ListVectorDBsResponse'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: List all vector databases.
-      description: List all vector databases.
-      parameters: []
-      deprecated: false
-    post:
-      responses:
-        '200':
-          description: A VectorDB.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/VectorDB'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Register a vector database.
-      description: Register a vector database.
-      parameters: []
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/RegisterVectorDbRequest'
-        required: true
-      deprecated: false
-  /v1/vector-dbs/{vector_db_id}:
-    get:
-      responses:
-        '200':
-          description: A VectorDB.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/VectorDB'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Get a vector database by its identifier.
-      description: Get a vector database by its identifier.
-      parameters:
-        - name: vector_db_id
-          in: path
-          description: >-
-            The identifier of the vector database to get.
-          required: true
-          schema:
-            type: string
-      deprecated: false
-    delete:
-      responses:
-        '200':
-          description: OK
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Unregister a vector database.
-      description: Unregister a vector database.
-      parameters:
-        - name: vector_db_id
-          in: path
-          description: >-
-            The identifier of the vector database to unregister.
-          required: true
-          schema:
-            type: string
-      deprecated: false
   /v1/vector-io/insert:
     post:
       responses:
@@ -8910,91 +8796,6 @@ components:
       - toolgroup_id
       - provider_id
     title: RegisterToolGroupRequest
-  VectorDB:
-    type: object
-    properties:
-      identifier:
-        type: string
-      provider_resource_id:
-        type: string
-      provider_id:
-        type: string
-      type:
-        type: string
-        enum:
-          - model
-          - shield
-          - vector_db
-          - dataset
-          - scoring_function
-          - benchmark
-          - tool
-          - tool_group
-          - prompt
-        const: vector_db
-        default: vector_db
-        description: >-
-          Type of resource, always 'vector_db' for vector databases
-      embedding_model:
-        type: string
-        description: >-
-          Name of the embedding model to use for vector generation
-      embedding_dimension:
-        type: integer
-        description: Dimension of the embedding vectors
-      vector_db_name:
-        type: string
-    additionalProperties: false
-    required:
-      - identifier
-      - provider_id
-      - type
-      - embedding_model
-      - embedding_dimension
-    title: VectorDB
-    description: >-
-      Vector database resource for storing and querying vector embeddings.
-  ListVectorDBsResponse:
-    type: object
-    properties:
-      data:
-        type: array
-        items:
-          $ref: '#/components/schemas/VectorDB'
-        description: List of vector databases
-    additionalProperties: false
-    required:
-      - data
-    title: ListVectorDBsResponse
-    description: Response from listing vector databases.
-  RegisterVectorDbRequest:
-    type: object
-    properties:
-      vector_db_id:
-        type: string
-        description: >-
-          The identifier of the vector database to register.
-      embedding_model:
-        type: string
-        description: The embedding model to use.
-      embedding_dimension:
-        type: integer
-        description: The dimension of the embedding model.
-      provider_id:
-        type: string
-        description: The identifier of the provider.
-      vector_db_name:
-        type: string
-        description: The name of the vector database.
-      provider_vector_db_id:
-        type: string
-        description: >-
-          The identifier of the vector database in the provider.
-    additionalProperties: false
-    required:
-      - vector_db_id
-      - embedding_model
-    title: RegisterVectorDbRequest
   Chunk:
     type: object
     properties:
@@ -10164,8 +9965,6 @@ tags:
     description: ''
   - name: ToolRuntime
     description: ''
-  - name: VectorDBs
-    description: ''
   - name: VectorIO
     description: ''
 x-tagGroups:
@@ -10187,5 +9986,4 @@ x-tagGroups:
       - Telemetry
       - ToolGroups
       - ToolRuntime
-      - VectorDBs
       - VectorIO
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 36c63367c..5c7a23e7b 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -2987,165 +2987,6 @@
         "deprecated": false
       }
     },
-    "/v1/vector-dbs": {
-      "get": {
-        "responses": {
-          "200": {
-            "description": "A ListVectorDBsResponse.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/ListVectorDBsResponse"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "List all vector databases.",
-        "description": "List all vector databases.",
-        "parameters": [],
-        "deprecated": false
-      },
-      "post": {
-        "responses": {
-          "200": {
-            "description": "A VectorDB.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/VectorDB"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Register a vector database.",
-        "description": "Register a vector database.",
-        "parameters": [],
-        "requestBody": {
-          "content": {
-            "application/json": {
-              "schema": {
-                "$ref": "#/components/schemas/RegisterVectorDbRequest"
-              }
-            }
-          },
-          "required": true
-        },
-        "deprecated": false
-      }
-    },
-    "/v1/vector-dbs/{vector_db_id}": {
-      "get": {
-        "responses": {
-          "200": {
-            "description": "A VectorDB.",
-            "content": {
-              "application/json": {
-                "schema": {
-                  "$ref": "#/components/schemas/VectorDB"
-                }
-              }
-            }
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Get a vector database by its identifier.",
-        "description": "Get a vector database by its identifier.",
-        "parameters": [
-          {
-            "name": "vector_db_id",
-            "in": "path",
-            "description": "The identifier of the vector database to get.",
-            "required": true,
-            "schema": {
-              "type": "string"
-            }
-          }
-        ],
-        "deprecated": false
-      },
-      "delete": {
-        "responses": {
-          "200": {
-            "description": "OK"
-          },
-          "400": {
-            "$ref": "#/components/responses/BadRequest400"
-          },
-          "429": {
-            "$ref": "#/components/responses/TooManyRequests429"
-          },
-          "500": {
-            "$ref": "#/components/responses/InternalServerError500"
-          },
-          "default": {
-            "$ref": "#/components/responses/DefaultError"
-          }
-        },
-        "tags": [
-          "VectorDBs"
-        ],
-        "summary": "Unregister a vector database.",
-        "description": "Unregister a vector database.",
-        "parameters": [
-          {
-            "name": "vector_db_id",
-            "in": "path",
-            "description": "The identifier of the vector database to unregister.",
-            "required": true,
-            "schema": {
-              "type": "string"
-            }
-          }
-        ],
-        "deprecated": false
-      }
-    },
     "/v1/vector-io/insert": {
       "post": {
         "responses": {
@@ -13800,111 +13641,6 @@
         ],
         "title": "RegisterToolGroupRequest"
       },
-      "VectorDB": {
-        "type": "object",
-        "properties": {
-          "identifier": {
-            "type": "string"
-          },
-          "provider_resource_id": {
-            "type": "string"
-          },
-          "provider_id": {
-            "type": "string"
-          },
-          "type": {
-            "type": "string",
-            "enum": [
-              "model",
-              "shield",
-              "vector_db",
-              "dataset",
-              "scoring_function",
-              "benchmark",
-              "tool",
-              "tool_group",
-              "prompt"
-            ],
-            "const": "vector_db",
-            "default": "vector_db",
-            "description": "Type of resource, always 'vector_db' for vector databases"
-          },
-          "embedding_model": {
-            "type": "string",
-            "description": "Name of the embedding model to use for vector generation"
-          },
-          "embedding_dimension": {
-            "type": "integer",
-            "description": "Dimension of the embedding vectors"
-          },
-          "vector_db_name": {
-            "type": "string"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "identifier",
-          "provider_id",
-          "type",
-          "embedding_model",
-          "embedding_dimension"
-        ],
-        "title": "VectorDB",
-        "description": "Vector database resource for storing and querying vector embeddings."
-      },
-      "ListVectorDBsResponse": {
-        "type": "object",
-        "properties": {
-          "data": {
-            "type": "array",
-            "items": {
-              "$ref": "#/components/schemas/VectorDB"
-            },
-            "description": "List of vector databases"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "data"
-        ],
-        "title": "ListVectorDBsResponse",
-        "description": "Response from listing vector databases."
-      },
-      "RegisterVectorDbRequest": {
-        "type": "object",
-        "properties": {
-          "vector_db_id": {
-            "type": "string",
-            "description": "The identifier of the vector database to register."
-          },
-          "embedding_model": {
-            "type": "string",
-            "description": "The embedding model to use."
-          },
-          "embedding_dimension": {
-            "type": "integer",
-            "description": "The dimension of the embedding model."
-          },
-          "provider_id": {
-            "type": "string",
-            "description": "The identifier of the provider."
-          },
-          "vector_db_name": {
-            "type": "string",
-            "description": "The name of the vector database."
-          },
-          "provider_vector_db_id": {
-            "type": "string",
-            "description": "The identifier of the vector database in the provider."
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "vector_db_id",
-          "embedding_model"
-        ],
-        "title": "RegisterVectorDbRequest"
-      },
       "Chunk": {
         "type": "object",
         "properties": {
@@ -18948,10 +18684,6 @@
       "name": "ToolRuntime",
       "description": ""
     },
-    {
-      "name": "VectorDBs",
-      "description": ""
-    },
     {
       "name": "VectorIO",
       "description": ""
@@ -18982,7 +18714,6 @@
         "Telemetry",
         "ToolGroups",
         "ToolRuntime",
-        "VectorDBs",
         "VectorIO"
       ]
     }
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index 4475cc8f0..45a76613e 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -2278,120 +2278,6 @@ paths:
         schema:
           type: string
       deprecated: false
-  /v1/vector-dbs:
-    get:
-      responses:
-        '200':
-          description: A ListVectorDBsResponse.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/ListVectorDBsResponse'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: List all vector databases.
-      description: List all vector databases.
-      parameters: []
-      deprecated: false
-    post:
-      responses:
-        '200':
-          description: A VectorDB.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/VectorDB'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Register a vector database.
-      description: Register a vector database.
-      parameters: []
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/RegisterVectorDbRequest'
-        required: true
-      deprecated: false
-  /v1/vector-dbs/{vector_db_id}:
-    get:
-      responses:
-        '200':
-          description: A VectorDB.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/VectorDB'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Get a vector database by its identifier.
-      description: Get a vector database by its identifier.
-      parameters:
-        - name: vector_db_id
-          in: path
-          description: >-
-            The identifier of the vector database to get.
-          required: true
-          schema:
-            type: string
-      deprecated: false
-    delete:
-      responses:
-        '200':
-          description: OK
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - VectorDBs
-      summary: Unregister a vector database.
-      description: Unregister a vector database.
-      parameters:
-        - name: vector_db_id
-          in: path
-          description: >-
-            The identifier of the vector database to unregister.
-          required: true
-          schema:
-            type: string
-      deprecated: false
   /v1/vector-io/insert:
     post:
       responses:
@@ -10355,91 +10241,6 @@ components:
      - toolgroup_id
      - provider_id
    title: RegisterToolGroupRequest
-  VectorDB:
-    type: object
-    properties:
-      identifier:
-        type: string
-      provider_resource_id:
-        type: string
-      provider_id:
-        type: string
-      type:
-        type: string
-        enum:
-          - model
-          - shield
-          - vector_db
-          - dataset
-          - scoring_function
-          - benchmark
-          - tool
-          - tool_group
-          - prompt
-        const: vector_db
-        default: vector_db
-        description: >-
-          Type of resource, always 'vector_db' for vector databases
-      embedding_model:
-        type: string
-        description: >-
-          Name of the embedding model to use for vector generation
-      embedding_dimension:
-        type: integer
-        description: Dimension of the embedding vectors
-      vector_db_name:
-        type: string
-    additionalProperties: false
-    required:
-      - identifier
-      - provider_id
-      - type
-      - embedding_model
-      - embedding_dimension
-    title: VectorDB
-    description: >-
-      Vector database resource for storing and querying vector embeddings.
-  ListVectorDBsResponse:
-    type: object
-    properties:
-      data:
-        type: array
-        items:
-          $ref: '#/components/schemas/VectorDB'
-        description: List of vector databases
-    additionalProperties: false
-    required:
-      - data
-    title: ListVectorDBsResponse
-    description: Response from listing vector databases.
-  RegisterVectorDbRequest:
-    type: object
-    properties:
-      vector_db_id:
-        type: string
-        description: >-
-          The identifier of the vector database to register.
-      embedding_model:
-        type: string
-        description: The embedding model to use.
-      embedding_dimension:
-        type: integer
-        description: The dimension of the embedding model.
-      provider_id:
-        type: string
-        description: The identifier of the provider.
-      vector_db_name:
-        type: string
-        description: The name of the vector database.
-      provider_vector_db_id:
-        type: string
-        description: >-
-          The identifier of the vector database in the provider.
-    additionalProperties: false
-    required:
-      - vector_db_id
-      - embedding_model
-    title: RegisterVectorDbRequest
   Chunk:
     type: object
     properties:
@@ -14212,8 +14013,6 @@ tags:
     description: ''
   - name: ToolRuntime
     description: ''
-  - name: VectorDBs
-    description: ''
   - name: VectorIO
     description: ''
 x-tagGroups:
@@ -14240,5 +14039,4 @@ x-tagGroups:
       - Telemetry
      - ToolGroups
      - ToolRuntime
-      - VectorDBs
      - VectorIO
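Note: the removed /v1/vector-dbs routes were a thin wrapper over the OpenAI-compatible
vector-store endpoints, which remain in the spec (the deleted routing table further down
already forwarded register_vector_db to the provider's openai_create_vector_store). A
client-side migration sketch, with placeholder IDs and under the assumption that the
installed llama_stack_client exposes the OpenAI-style vector_stores resource:

    # Hypothetical migration sketch; names and signatures are assumptions,
    # not verified against a specific client release.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

    # Before this patch:
    # client.vector_dbs.register(
    #     vector_db_id="my-docs",
    #     embedding_model="all-MiniLM-L6-v2",
    #     embedding_dimension=384,
    # )

    # After: create an OpenAI-compatible vector store instead.
    store = client.vector_stores.create(name="my-docs")

    # list / get / unregister map onto list / retrieve / delete.
    client.vector_stores.list()
    client.vector_stores.retrieve(vector_store_id=store.id)
    client.vector_stores.delete(vector_store_id=store.id)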
diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py
index 521d129c6..53bf181e9 100644
--- a/llama_stack/apis/vector_dbs/vector_dbs.py
+++ b/llama_stack/apis/vector_dbs/vector_dbs.py
@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Literal, Protocol, runtime_checkable
+from typing import Literal
 
 from pydantic import BaseModel
 
 from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_stack.schema_utils import json_schema_type, webmethod
+from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
@@ -61,57 +59,3 @@ class ListVectorDBsResponse(BaseModel):
     """
 
     data: list[VectorDB]
-
-
-@runtime_checkable
-@trace_protocol
-class VectorDBs(Protocol):
-    @webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
-    async def list_vector_dbs(self) -> ListVectorDBsResponse:
-        """List all vector databases.
-
-        :returns: A ListVectorDBsResponse.
-        """
-        ...
-
-    @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
-    async def get_vector_db(
-        self,
-        vector_db_id: str,
-    ) -> VectorDB:
-        """Get a vector database by its identifier.
-
-        :param vector_db_id: The identifier of the vector database to get.
-        :returns: A VectorDB.
-        """
-        ...
-
-    @webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
-    async def register_vector_db(
-        self,
-        vector_db_id: str,
-        embedding_model: str,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        vector_db_name: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorDB:
-        """Register a vector database.
-
-        :param vector_db_id: The identifier of the vector database to register.
-        :param embedding_model: The embedding model to use.
-        :param embedding_dimension: The dimension of the embedding model.
-        :param provider_id: The identifier of the provider.
-        :param vector_db_name: The name of the vector database.
-        :param provider_vector_db_id: The identifier of the vector database in the provider.
-        :returns: A VectorDB.
-        """
-        ...
-
-    @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
-    async def unregister_vector_db(self, vector_db_id: str) -> None:
-        """Unregister a vector database.
-
-        :param vector_db_id: The identifier of the vector database to unregister.
-        """
-        ...
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 749253865..c8011307d 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -28,7 +28,6 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
-from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
 from llama_stack.core.client import get_client_impl
@@ -81,7 +80,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.inspect: Inspect,
         Api.batches: Batches,
         Api.vector_io: VectorIO,
-        Api.vector_dbs: VectorDBs,
         Api.models: Models,
         Api.safety: Safety,
         Api.shields: Shields,
@@ -125,7 +123,7 @@ def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
         Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
-        Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
+        Api.vector_io: (VectorDBsProtocolPrivate,),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
         Api.scoring: (
diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py
index f129f8ede..a1a8b0144 100644
--- a/llama_stack/core/routers/__init__.py
+++ b/llama_stack/core/routers/__init__.py
@@ -26,10 +26,8 @@ async def get_routing_table_impl(
     from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
     from ..routing_tables.shields import ShieldsRoutingTable
     from ..routing_tables.toolgroups import ToolGroupsRoutingTable
-    from ..routing_tables.vector_dbs import VectorDBsRoutingTable
 
     api_to_tables = {
-        "vector_dbs": VectorDBsRoutingTable,
         "models": ModelsRoutingTable,
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
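For context on the protocol deleted above: @webmethod attaches route metadata to a
protocol method, and the server wires HTTP endpoints from that metadata. A simplified
stand-in for illustration only (the real decorator in llama_stack.schema_utils carries
more fields, such as deprecation info):

    # Illustrative sketch of the decorator pattern, not llama_stack's implementation.
    from typing import Any, Callable

    def webmethod(route: str, method: str = "GET", level: str = "v1") -> Callable:
        def decorator(fn: Callable) -> Callable:
            # stash routing info on the function for the server to discover
            fn.__webmethod__ = {"route": route, "method": method, "level": level}
            return fn
        return decorator

    class ExampleProtocol:
        @webmethod(route="/vector-dbs", method="GET")
        async def list_vector_dbs(self) -> Any: ...

    # a router can scan class attributes for __webmethod__ to build its route table
    print(ExampleProtocol.list_vector_dbs.__webmethod__)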
diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py
index 0800b909b..0b5aa7843 100644
--- a/llama_stack/core/routing_tables/common.py
+++ b/llama_stack/core/routing_tables/common.py
@@ -134,15 +134,12 @@ class CommonRoutingTableImpl(RoutingTable):
         from .scoring_functions import ScoringFunctionsRoutingTable
         from .shields import ShieldsRoutingTable
         from .toolgroups import ToolGroupsRoutingTable
-        from .vector_dbs import VectorDBsRoutingTable
 
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
             elif isinstance(self, ShieldsRoutingTable):
                 return ("Safety", "shield")
-            elif isinstance(self, VectorDBsRoutingTable):
-                return ("VectorIO", "vector_db")
             elif isinstance(self, DatasetsRoutingTable):
                 return ("DatasetIO", "dataset")
             elif isinstance(self, ScoringFunctionsRoutingTable):
diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py
deleted file mode 100644
index 995e0351d..000000000
--- a/llama_stack/core/routing_tables/vector_dbs.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any
-
-from pydantic import TypeAdapter
-
-from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
-from llama_stack.apis.models import ModelType
-from llama_stack.apis.resource import ResourceType
-from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
-from llama_stack.apis.vector_io.vector_io import (
-    SearchRankingOptions,
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileDeleteResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
-)
-from llama_stack.core.datatypes import (
-    VectorDBWithOwner,
-)
-from llama_stack.log import get_logger
-
-from .common import CommonRoutingTableImpl, lookup_model
-
-logger = get_logger(name=__name__, category="core::routing_tables")
-
-
-class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
-    async def list_vector_dbs(self) -> ListVectorDBsResponse:
-        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
-
-    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
-        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
-        if vector_db is None:
-            raise VectorStoreNotFoundError(vector_db_id)
-        return vector_db
-
-    async def register_vector_db(
-        self,
-        vector_db_id: str,
-        embedding_model: str,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-        vector_db_name: str | None = None,
-    ) -> VectorDB:
-        if provider_id is None:
-            if len(self.impls_by_provider_id) > 0:
-                provider_id = list(self.impls_by_provider_id.keys())[0]
-                if len(self.impls_by_provider_id) > 1:
-                    logger.warning(
-                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
-                    )
-            else:
-                raise ValueError("No provider available. Please configure a vector_io provider.")
-        model = await lookup_model(self, embedding_model)
-        if model is None:
-            raise ModelNotFoundError(embedding_model)
-        if model.model_type != ModelType.embedding:
-            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
-        if "embedding_dimension" not in model.metadata:
-            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
-
-        provider = self.impls_by_provider_id[provider_id]
-        vector_store = await provider.openai_create_vector_store(
-            name=vector_db_name or vector_db_id,
-            embedding_model=embedding_model,
-            embedding_dimension=model.metadata["embedding_dimension"],
-            provider_id=provider_id,
-            provider_vector_db_id=provider_vector_db_id,
-        )
-
-        vector_store_id = vector_store.id
-        actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
-        logger.warning(
-            f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
-        )
-
-        vector_db_data = {
-            "identifier": vector_store_id,
-            "type": ResourceType.vector_db.value,
-            "provider_id": provider_id,
-            "provider_resource_id": actual_provider_vector_db_id,
-            "embedding_model": embedding_model,
-            "embedding_dimension": model.metadata["embedding_dimension"],
-            "vector_db_name": vector_store.name,
-        }
-        vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
-        await self.register_object(vector_db)
-        return vector_db
-
-    async def unregister_vector_db(self, vector_db_id: str) -> None:
-        existing_vector_db = await self.get_vector_db(vector_db_id)
-        await self.unregister_object(existing_vector_db)
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_retrieve_vector_store(vector_store_id)
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_update_vector_store(
-            vector_store_id=vector_store_id,
-            name=name,
-            expires_after=expires_after,
-            metadata=metadata,
-        )
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        result = await provider.openai_delete_vector_store(vector_store_id)
-        await self.unregister_vector_db(vector_store_id)
-        return result
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_search_vector_store(
-            vector_store_id=vector_store_id,
-            query=query,
-            filters=filters,
-            max_num_results=max_num_results,
-            ranking_options=ranking_options,
-            rewrite_query=rewrite_query,
-            search_mode=search_mode,
-        )
-
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_attach_file_to_vector_store(
-            vector_store_id=vector_store_id,
-            file_id=file_id,
-            attributes=attributes,
-            chunking_strategy=chunking_strategy,
-        )
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> list[VectorStoreFileObject]:
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_list_files_in_vector_store(
-            vector_store_id=vector_store_id,
-            limit=limit,
-            order=order,
-            after=after,
-            before=before,
-            filter=filter,
-        )
-
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_retrieve_vector_store_file(
-            vector_store_id=vector_store_id,
-            file_id=file_id,
-        )
-
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_retrieve_vector_store_file_contents(
-            vector_store_id=vector_store_id,
-            file_id=file_id,
-        )
-
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any],
-    ) -> VectorStoreFileObject:
-        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_update_vector_store_file(
-            vector_store_id=vector_store_id,
-            file_id=file_id,
-            attributes=attributes,
-        )
-
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileDeleteResponse:
-        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_delete_vector_store_file(
-            vector_store_id=vector_store_id,
-            file_id=file_id,
-        )
-
-    async def openai_create_vector_store_file_batch(
-        self,
-        vector_store_id: str,
-        file_ids: list[str],
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: Any | None = None,
-    ):
-        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_create_vector_store_file_batch(
-            vector_store_id=vector_store_id,
-            file_ids=file_ids,
-            attributes=attributes,
-            chunking_strategy=chunking_strategy,
-        )
-
-    async def openai_retrieve_vector_store_file_batch(
-        self,
-        batch_id: str,
-        vector_store_id: str,
-    ):
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_retrieve_vector_store_file_batch(
-            batch_id=batch_id,
-            vector_store_id=vector_store_id,
-        )
-
-    async def openai_list_files_in_vector_store_file_batch(
-        self,
-        batch_id: str,
-        vector_store_id: str,
-        after: str | None = None,
-        before: str | None = None,
-        filter: str | None = None,
-        limit: int | None = 20,
-        order: str | None = "desc",
-    ):
-        await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_list_files_in_vector_store_file_batch(
-            batch_id=batch_id,
-            vector_store_id=vector_store_id,
-            after=after,
-            before=before,
-            filter=filter,
-            limit=limit,
-            order=order,
-        )
-
-    async def openai_cancel_vector_store_file_batch(
-        self,
-        batch_id: str,
-        vector_store_id: str,
-    ):
-        await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        provider = await self.get_provider_impl(vector_store_id)
-        return await provider.openai_cancel_vector_store_file_batch(
-            batch_id=batch_id,
-            vector_store_id=vector_store_id,
-        )
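Every OpenAI-compat passthrough in the deleted routing table above followed the same
guard-then-delegate shape; only the access check and the provider lookup were
routing-table concerns. Condensed for reference (names exactly as in the deleted file;
the class no longer exists after this patch, so this is purely descriptive):

    # Condensed from the deleted VectorDBsRoutingTable methods above.
    async def openai_retrieve_vector_store(self, vector_store_id: str):
        # 1. enforce access control for the calling principal
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        # 2. resolve which vector_io provider owns this store
        provider = await self.get_provider_impl(vector_store_id)
        # 3. forward the call unchanged
        return await provider.openai_retrieve_vector_store(vector_store_id)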
diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py
index 2eab9344f..f161ac358 100644
--- a/llama_stack/core/stack.py
+++ b/llama_stack/core/stack.py
@@ -33,7 +33,6 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
-from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
 from llama_stack.core.datatypes import Provider, StackRunConfig
@@ -53,7 +52,6 @@ logger = get_logger(name=__name__, category="core")
 
 class LlamaStack(
     Providers,
-    VectorDBs,
     Inference,
     Agents,
     Safety,
@@ -83,7 +81,6 @@ class LlamaStack(
 RESOURCES = [
     ("models", Api.models, "register_model", "list_models"),
     ("shields", Api.shields, "register_shield", "list_shields"),
-    ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
     ("datasets", Api.datasets, "register_dataset", "list_datasets"),
     (
         "scoring_fns",
diff --git a/llama_stack/core/ui/page/distribution/resources.py b/llama_stack/core/ui/page/distribution/resources.py
index c56fcfff3..6e7122ceb 100644
--- a/llama_stack/core/ui/page/distribution/resources.py
+++ b/llama_stack/core/ui/page/distribution/resources.py
@@ -11,19 +11,17 @@ from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks
 from llama_stack.core.ui.page.distribution.models import models
 from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions
 from llama_stack.core.ui.page.distribution.shields import shields
-from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs
 
 
 def resources_page():
     options = [
         "Models",
-        "Vector Databases",
         "Shields",
         "Scoring Functions",
         "Datasets",
         "Benchmarks",
     ]
-    icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
+    icons = ["magic", "shield", "file-bar-graph", "database", "list-task"]
     selected_resource = option_menu(
         None,
         options,
@@ -37,8 +35,6 @@ def resources_page():
     )
     if selected_resource == "Benchmarks":
         benchmarks()
-    elif selected_resource == "Vector Databases":
-        vector_dbs()
     elif selected_resource == "Datasets":
         datasets()
     elif selected_resource == "Models":
diff --git a/llama_stack/core/ui/page/distribution/vector_dbs.py b/llama_stack/core/ui/page/distribution/vector_dbs.py
deleted file mode 100644
index e81077d2a..000000000
--- a/llama_stack/core/ui/page/distribution/vector_dbs.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import streamlit as st
-
-from llama_stack.core.ui.modules.api import llama_stack_api
-
-
-def vector_dbs():
-    st.header("Vector Databases")
-    vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
-
-    if len(vector_dbs_info) > 0:
-        selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
-        st.json(vector_dbs_info[selected_vector_db])
-    else:
-        st.info("No vector databases found")
diff --git a/llama_stack/core/ui/page/playground/rag.py b/llama_stack/core/ui/page/playground/rag.py
deleted file mode 100644
index 2ffae1c33..000000000
--- a/llama_stack/core/ui/page/playground/rag.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import uuid
-
-import streamlit as st
-from llama_stack_client import Agent, AgentEventLogger, RAGDocument
-
-from llama_stack.apis.common.content_types import ToolCallDelta
-from llama_stack.core.ui.modules.api import llama_stack_api
-from llama_stack.core.ui.modules.utils import data_url_from_file
-
-
-def rag_chat_page():
-    st.title("🦙 RAG")
-
-    def reset_agent_and_chat():
-        st.session_state.clear()
-        st.cache_resource.clear()
-
-    def should_disable_input():
-        return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
-
-    def log_message(message):
-        with st.chat_message(message["role"]):
-            if "tool_output" in message and message["tool_output"]:
-                with st.expander(label="Tool Output", expanded=False, icon="🛠"):
-                    st.write(message["tool_output"])
-            st.markdown(message["content"])
-
-    with st.sidebar:
-        # File/Directory Upload Section
-        st.subheader("Upload Documents", divider=True)
-        uploaded_files = st.file_uploader(
-            "Upload file(s) or directory",
-            accept_multiple_files=True,
-            type=["txt", "pdf", "doc", "docx"],  # Add more file types as needed
-        )
-        # Process uploaded files
-        if uploaded_files:
-            st.success(f"Successfully uploaded {len(uploaded_files)} files")
-            # Add memory bank name input field
-            vector_db_name = st.text_input(
-                "Document Collection Name",
-                value="rag_vector_db",
-                help="Enter a unique identifier for this document collection",
-            )
-            if st.button("Create Document Collection"):
-                documents = [
-                    RAGDocument(
-                        document_id=uploaded_file.name,
-                        content=data_url_from_file(uploaded_file),
-                    )
-                    for i, uploaded_file in enumerate(uploaded_files)
-                ]
-
-                providers = llama_stack_api.client.providers.list()
-                vector_io_provider = None
-
-                for x in providers:
-                    if x.api == "vector_io":
-                        vector_io_provider = x.provider_id
-
-                llama_stack_api.client.vector_dbs.register(
-                    vector_db_id=vector_db_name,  # Use the user-provided name
-                    embedding_dimension=384,
-                    embedding_model="all-MiniLM-L6-v2",
-                    provider_id=vector_io_provider,
-                )
-
-                # insert documents using the custom vector db name
-                llama_stack_api.client.tool_runtime.rag_tool.insert(
-                    vector_db_id=vector_db_name,  # Use the user-provided name
-                    documents=documents,
-                    chunk_size_in_tokens=512,
-                )
-                st.success("Vector database created successfully!")
-
-        st.subheader("RAG Parameters", divider=True)
-
-        rag_mode = st.radio(
-            "RAG mode",
-            ["Direct", "Agent-based"],
-            captions=[
-                "RAG is performed by directly retrieving the information and augmenting the user query",
-                "RAG is performed by an agent activating a dedicated knowledge search tool.",
-            ],
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-
-        # select memory banks
-        vector_dbs = llama_stack_api.client.vector_dbs.list()
-        vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
-        selected_vector_dbs = st.multiselect(
-            label="Select Document Collections to use in RAG queries",
-            options=vector_dbs,
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-
-        st.subheader("Inference Parameters", divider=True)
-        available_models = llama_stack_api.client.models.list()
-        available_models = [model.identifier for model in available_models if model.model_type == "llm"]
-        selected_model = st.selectbox(
-            label="Choose a model",
-            options=available_models,
-            index=0,
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-        system_prompt = st.text_area(
-            "System Prompt",
-            value="You are a helpful assistant. ",
-            help="Initial instructions given to the AI to set its behavior and context",
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-        temperature = st.slider(
-            "Temperature",
-            min_value=0.0,
-            max_value=1.0,
-            value=0.0,
-            step=0.1,
-            help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-
-        top_p = st.slider(
-            "Top P",
-            min_value=0.0,
-            max_value=1.0,
-            value=0.95,
-            step=0.1,
-            on_change=reset_agent_and_chat,
-            disabled=should_disable_input(),
-        )
-
-        # Add clear chat button to sidebar
-        if st.button("Clear Chat", use_container_width=True):
-            reset_agent_and_chat()
-            st.rerun()
-
-    # Chat Interface
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-    if "displayed_messages" not in st.session_state:
-        st.session_state.displayed_messages = []
-
-    # Display chat history
-    for message in st.session_state.displayed_messages:
-        log_message(message)
-
-    if temperature > 0.0:
-        strategy = {
-            "type": "top_p",
-            "temperature": temperature,
-            "top_p": top_p,
-        }
-    else:
-        strategy = {"type": "greedy"}
-
-    @st.cache_resource
-    def create_agent():
-        return Agent(
-            llama_stack_api.client,
-            model=selected_model,
-            instructions=system_prompt,
-            sampling_params={
-                "strategy": strategy,
-            },
-            tools=[
-                dict(
-                    name="builtin::rag/knowledge_search",
-                    args={
-                        "vector_db_ids": list(selected_vector_dbs),
-                    },
-                )
-            ],
-        )
-
-    if rag_mode == "Agent-based":
-        agent = create_agent()
-        if "agent_session_id" not in st.session_state:
-            st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
-
-        session_id = st.session_state["agent_session_id"]
-
-    def agent_process_prompt(prompt):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-
-        # Send the prompt to the agent
-        response = agent.create_turn(
-            messages=[
-                {
-                    "role": "user",
-                    "content": prompt,
-                }
-            ],
-            session_id=session_id,
-        )
-
-        # Display assistant response
-        with st.chat_message("assistant"):
-            retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
-            message_placeholder = st.empty()
-            full_response = ""
-            retrieval_response = ""
-            for log in AgentEventLogger().log(response):
-                log.print()
-                if log.role == "tool_execution":
-                    retrieval_response += log.content.replace("====", "").strip()
-                    retrieval_message_placeholder.write(retrieval_response)
-                else:
-                    full_response += log.content
-                    message_placeholder.markdown(full_response + "▌")
-            message_placeholder.markdown(full_response)
-
-        st.session_state.messages.append({"role": "assistant", "content": full_response})
-        st.session_state.displayed_messages.append(
-            {"role": "assistant", "content": full_response, "tool_output": retrieval_response}
-        )
-
-    def direct_process_prompt(prompt):
-        # Add the system prompt in the beginning of the conversation
-        if len(st.session_state.messages) == 0:
-            st.session_state.messages.append({"role": "system", "content": system_prompt})
-
-        # Query the vector DB
-        rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
-            content=prompt, vector_db_ids=list(selected_vector_dbs)
-        )
-        prompt_context = rag_response.content
-
-        with st.chat_message("assistant"):
-            with st.expander(label="Retrieval Output", expanded=False):
-                st.write(prompt_context)
-
-            retrieval_message_placeholder = st.empty()
-            message_placeholder = st.empty()
-            full_response = ""
-            retrieval_response = ""
-
-            # Construct the extended prompt
-            extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
-
-            # Run inference directly
-            st.session_state.messages.append({"role": "user", "content": extended_prompt})
-            response = llama_stack_api.client.inference.chat_completion(
-                messages=st.session_state.messages,
-                model_id=selected_model,
-                sampling_params={
-                    "strategy": strategy,
-                },
-                stream=True,
-            )
-
-            # Display assistant response
-            for chunk in response:
-                response_delta = chunk.event.delta
-                if isinstance(response_delta, ToolCallDelta):
-                    retrieval_response += response_delta.tool_call.replace("====", "").strip()
-                    retrieval_message_placeholder.info(retrieval_response)
-                else:
-                    full_response += chunk.event.delta.text
-                    message_placeholder.markdown(full_response + "▌")
-            message_placeholder.markdown(full_response)
-
-        response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
-        st.session_state.messages.append(response_dict)
-        st.session_state.displayed_messages.append(response_dict)
-
-    # Chat input
-    if prompt := st.chat_input("Ask a question about your documents"):
-        # Add user message to chat history
-        st.session_state.displayed_messages.append({"role": "user", "content": prompt})
-
-        # Display user message
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # store the prompt to process it after page refresh
-        st.session_state.prompt = prompt
-
-        # force page refresh to disable the settings widgets
-        st.rerun()
-
-    if "prompt" in st.session_state and st.session_state.prompt is not None:
-        if rag_mode == "Agent-based":
-            agent_process_prompt(st.session_state.prompt)
-        else:  # rag_mode == "Direct"
-            direct_process_prompt(st.session_state.prompt)
-        st.session_state.prompt = None
-
-
-rag_chat_page()
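The deleted playground page implemented two RAG modes; its "Direct" mode reduces to a
retrieve-then-augment loop that needs no Streamlit. A sketch distilled from the deleted
direct_process_prompt, with placeholder IDs and a non-streaming call for brevity:

    # Distilled from the deleted page; "my-docs" and "my-llm" are placeholders.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint
    prompt = "How do I configure LoRA?"

    # 1. retrieve context chunks via the RAG tool
    rag_response = client.tool_runtime.rag_tool.query(
        content=prompt, vector_db_ids=["my-docs"]
    )

    # 2. augment the user query with the retrieved context
    extended_prompt = (
        "Please answer the following query using the context below.\n\n"
        f"CONTEXT:\n{rag_response.content}\n\nQUERY:\n{prompt}"
    )

    # 3. run plain inference on the augmented prompt
    response = client.inference.chat_completion(
        messages=[{"role": "user", "content": extended_prompt}],
        model_id="my-llm",
    )
    print(response.completion_message.content)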
Document from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig @@ -443,118 +442,6 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice): return [step for step in response.steps if step.step_type == "tool_execution"] -@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]) -def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): - urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] - documents = [ - Document( - document_id=f"num-{i}", - content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", - mime_type="text/plain", - metadata={}, - ) - for i, url in enumerate(urls) - ] - vector_db_id = f"test-vector-db-{uuid4()}" - llama_stack_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - embedding_dimension=384, - ) - llama_stack_client.tool_runtime.rag_tool.insert( - documents=documents, - vector_db_id=vector_db_id, - # small chunks help to get specific info out of the docs - chunk_size_in_tokens=256, - ) - agent_config = { - **agent_config, - "tools": [ - dict( - name=rag_tool_name, - args={ - "vector_db_ids": [vector_db_id], - }, - ) - ], - } - rag_agent = Agent(llama_stack_client, **agent_config) - session_id = rag_agent.create_session(f"test-session-{uuid4()}") - user_prompts = [ - ( - "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", - "grouped", - ), - ] - for prompt, expected_kw in user_prompts: - response = rag_agent.create_turn( - messages=[{"role": "user", "content": prompt}], - session_id=session_id, - stream=False, - ) - # rag is called - tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") - assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search" - # document ids are present in metadata - assert all( - doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"] - ) - if expected_kw: - assert expected_kw in response.output_message.content.lower() - - -def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety): - urls = ["llama3.rst", "lora_finetune.rst"] - documents = [ - # passign as url - Document( - document_id="num-0", - content={ - "type": "url", - "uri": f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[0]}", - }, - mime_type="text/plain", - metadata={}, - ), - # passing as str - Document( - document_id="num-1", - content=requests.get( - f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{urls[1]}" - ).text[:500], - mime_type="text/plain", - metadata={}, - ), - ] - rag_agent = Agent(llama_stack_client, **agent_config_without_safety) - session_id = rag_agent.create_session(f"test-session-{uuid4()}") - user_prompts = [ - ( - "I am attaching some documentation for Torchtune. 
Help me answer questions I will ask next.", - documents, - ), - ( - "Tell me how to use LoRA in 100 words or less", - None, - ), - ] - - for prompt in user_prompts: - response = rag_agent.create_turn( - messages=[ - { - "role": "user", - "content": prompt[0], - } - ], - documents=prompt[1], - session_id=session_id, - stream=False, - ) - - assert "lora" in response.output_message.content.lower() - - @pytest.mark.parametrize( "client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)], diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py deleted file mode 100644 index b78c39af8..000000000 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -from llama_stack_client import BadRequestError -from llama_stack_client.types import Document - - -@pytest.fixture(scope="function") -def client_with_empty_registry(client_with_models): - def clear_registry(): - vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()] - for vector_db_id in vector_dbs: - client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) - - clear_registry() - - try: - client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime") - except Exception: - pass - - yield client_with_models - - clear_registry() - - -@pytest.fixture(scope="session") -def sample_documents(): - return [ - Document( - document_id="test-doc-1", - content="Python is a high-level programming language.", - metadata={"category": "programming", "difficulty": "beginner"}, - ), - Document( - document_id="test-doc-2", - content="Machine learning is a subset of artificial intelligence.", - metadata={"category": "AI", "difficulty": "advanced"}, - ), - Document( - document_id="test-doc-3", - content="Data structures are fundamental to computer science.", - metadata={"category": "computer science", "difficulty": "intermediate"}, - ), - Document( - document_id="test-doc-4", - content="Neural networks are inspired by biological neural networks.", - metadata={"category": "AI", "difficulty": "advanced"}, - ), - ] - - -def assert_valid_chunk_response(response): - assert len(response.chunks) > 0 - assert len(response.scores) > 0 - assert len(response.chunks) == len(response.scores) - for chunk in response.chunks: - assert isinstance(chunk.content, str) - - -def assert_valid_text_response(response): - assert len(response.content) > 0 - assert all(isinstance(chunk.text, str) for chunk in response.content) - - -def test_vector_db_insert_inline_and_query( - client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension -): - vector_db_name = "test_vector_db" - vector_db = client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_name, - embedding_model=embedding_model_id, - embedding_dimension=embedding_dimension, - ) - vector_db_id = vector_db.identifier - - client_with_empty_registry.tool_runtime.rag_tool.insert( - documents=sample_documents, - chunk_size_in_tokens=512, - vector_db_id=vector_db_id, - ) - - # Query with a direct match - query1 = "programming language" - response1 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, - query=query1, - ) - assert_valid_chunk_response(response1) - assert any("Python" 
diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py
deleted file mode 100644
index b78c39af8..000000000
--- a/tests/integration/tool_runtime/test_rag_tool.py
+++ /dev/null
@@ -1,459 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-from llama_stack_client import BadRequestError
-from llama_stack_client.types import Document
-
-
-@pytest.fixture(scope="function")
-def client_with_empty_registry(client_with_models):
-    def clear_registry():
-        vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
-        for vector_db_id in vector_dbs:
-            client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
-
-    clear_registry()
-
-    try:
-        client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
-    except Exception:
-        pass
-
-    yield client_with_models
-
-    clear_registry()
-
-
-@pytest.fixture(scope="session")
-def sample_documents():
-    return [
-        Document(
-            document_id="test-doc-1",
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "difficulty": "beginner"},
-        ),
-        Document(
-            document_id="test-doc-2",
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-        Document(
-            document_id="test-doc-3",
-            content="Data structures are fundamental to computer science.",
-            metadata={"category": "computer science", "difficulty": "intermediate"},
-        ),
-        Document(
-            document_id="test-doc-4",
-            content="Neural networks are inspired by biological neural networks.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-    ]
-
-
-def assert_valid_chunk_response(response):
-    assert len(response.chunks) > 0
-    assert len(response.scores) > 0
-    assert len(response.chunks) == len(response.scores)
-    for chunk in response.chunks:
-        assert isinstance(chunk.content, str)
-
-
-def assert_valid_text_response(response):
-    assert len(response.content) > 0
-    assert all(isinstance(chunk.text, str) for chunk in response.content)
-
-
-def test_vector_db_insert_inline_and_query(
-    client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
-):
-    vector_db_name = "test_vector_db"
-    vector_db = client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_name,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-    vector_db_id = vector_db.identifier
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=sample_documents,
-        chunk_size_in_tokens=512,
-        vector_db_id=vector_db_id,
-    )
-
-    # Query with a direct match
-    query1 = "programming language"
-    response1 = client_with_empty_registry.vector_io.query(
-        vector_db_id=vector_db_id,
-        query=query1,
-    )
-    assert_valid_chunk_response(response1)
-    assert any("Python" in chunk.content for chunk in response1.chunks)
-
-    # Query with semantic similarity
-    query2 = "AI and brain-inspired computing"
-    response2 = client_with_empty_registry.vector_io.query(
-        vector_db_id=vector_db_id,
-        query=query2,
-    )
-    assert_valid_chunk_response(response2)
-    assert any("neural networks" in chunk.content.lower() for chunk in response2.chunks)
-
-    # Query with limit on number of results (max_chunks=2)
-    query3 = "computer"
-    response3 = client_with_empty_registry.vector_io.query(
-        vector_db_id=vector_db_id,
-        query=query3,
-        params={"max_chunks": 2},
-    )
-    assert_valid_chunk_response(response3)
-    assert len(response3.chunks) <= 2
-
-    # Query with threshold on similarity score
-    query4 = "computer"
-    response4 = client_with_empty_registry.vector_io.query(
-        vector_db_id=vector_db_id,
-        query=query4,
-        params={"score_threshold": 0.01},
-    )
-    assert_valid_chunk_response(response4)
-    assert all(score >= 0.01 for score in response4.scores)
-
-
-def test_vector_db_insert_from_url_and_query(
-    client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
-):
-    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
-    assert len(providers) > 0
-
-    vector_db_id = "test_vector_db"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    # list to check memory bank is successfully registered
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    # VectorDB is being migrated to VectorStore, so the ID will be different
-    # Just check that at least one vector DB was registered
-    assert len(available_vector_dbs) > 0
-    # Use the actual registered vector_db_id for subsequent operations
-    actual_vector_db_id = available_vector_dbs[0]
-
-    urls = [
-        "memory_optimizations.rst",
-        "chat.rst",
-        "llama3.rst",
-    ]
-    documents = [
-        Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
-            metadata={},
-        )
-        for i, url in enumerate(urls)
-    ]
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=actual_vector_db_id,
-        chunk_size_in_tokens=512,
-    )
-
-    # Query for the name of method
-    response1 = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_db_id,
-        query="What's the name of the fine-tunning method used?",
-    )
-    assert_valid_chunk_response(response1)
-    assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
-
-    # Query for the name of model
-    response2 = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_db_id,
-        query="Which Llama model is mentioned?",
-    )
-    assert_valid_chunk_response(response2)
-    assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
-
-
-def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
-    vector_db_id = "test_openai_vector_db"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    actual_vector_db_id = available_vector_dbs[0]
-
-    # different document formats that should work with OpenAI APIs
-    documents = [
-        Document(
-            document_id="text-doc",
- content="This is a plain text document about machine learning algorithms.", - metadata={"type": "text", "category": "AI"}, - ), - Document( - document_id="url-doc", - content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst", - mime_type="text/plain", - metadata={"type": "url", "source": "pytorch"}, - ), - Document( - document_id="data-url-doc", - content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning." - metadata={"type": "data_url", "encoding": "base64"}, - ), - ] - - client_with_empty_registry.tool_runtime.rag_tool.insert( - documents=documents, - vector_db_id=actual_vector_db_id, - chunk_size_in_tokens=256, - ) - - files_list = client_with_empty_registry.files.list() - assert len(files_list.data) >= len(documents), ( - f"Expected at least {len(documents)} files, got {len(files_list.data)}" - ) - - vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store( - vector_store_id=actual_vector_db_id - ) - assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store" - - response = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[actual_vector_db_id], - content="Tell me about machine learning and deep learning", - ) - - assert_valid_text_response(response) - content_text = " ".join([chunk.text for chunk in response.content]).lower() - assert "machine learning" in content_text or "deep learning" in content_text - - -def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension): - vector_db_id = "test_exception_handling" - - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model=embedding_model_id, - embedding_dimension=embedding_dimension, - ) - - available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - actual_vector_db_id = available_vector_dbs[0] - - documents = [ - Document( - document_id="valid-doc", - content="This is a valid document that should be processed successfully.", - metadata={"status": "valid"}, - ), - Document( - document_id="invalid-url-doc", - content="https://nonexistent-domain-12345.com/invalid.txt", - metadata={"status": "invalid_url"}, - ), - Document( - document_id="another-valid-doc", - content="This is another valid document for testing resilience.", - metadata={"status": "valid"}, - ), - ] - - client_with_empty_registry.tool_runtime.rag_tool.insert( - documents=documents, - vector_db_id=actual_vector_db_id, - chunk_size_in_tokens=256, - ) - - response = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[actual_vector_db_id], - content="valid document", - ) - - assert_valid_text_response(response) - content_text = " ".join([chunk.text for chunk in response.content]).lower() - assert "valid document" in content_text - - -def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension): - providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] - assert len(providers) > 0 - - vector_db_id = "test_vector_db" - - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model=embedding_model_id, - embedding_dimension=embedding_dimension, - ) - - available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - # VectorDB is 
-
-
-def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
-    vector_db_id = "test_exception_handling"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    actual_vector_db_id = available_vector_dbs[0]
-
-    documents = [
-        Document(
-            document_id="valid-doc",
-            content="This is a valid document that should be processed successfully.",
-            metadata={"status": "valid"},
-        ),
-        Document(
-            document_id="invalid-url-doc",
-            content="https://nonexistent-domain-12345.com/invalid.txt",
-            metadata={"status": "invalid_url"},
-        ),
-        Document(
-            document_id="another-valid-doc",
-            content="This is another valid document for testing resilience.",
-            metadata={"status": "valid"},
-        ),
-    ]
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=actual_vector_db_id,
-        chunk_size_in_tokens=256,
-    )
-
-    response = client_with_empty_registry.tool_runtime.rag_tool.query(
-        vector_db_ids=[actual_vector_db_id],
-        content="valid document",
-    )
-
-    assert_valid_text_response(response)
-    content_text = " ".join([chunk.text for chunk in response.content]).lower()
-    assert "valid document" in content_text
-
-
-def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
-    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
-    assert len(providers) > 0
-
-    vector_db_id = "test_vector_db"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    # VectorDB is being migrated to VectorStore, so the ID will be different
-    # Just check that at least one vector DB was registered
-    assert len(available_vector_dbs) > 0
-    # Use the actual registered vector_db_id for subsequent operations
-    actual_vector_db_id = available_vector_dbs[0]
-
-    urls = [
-        "memory_optimizations.rst",
-        "chat.rst",
-        "llama3.rst",
-    ]
-    documents = [
-        Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
-            metadata={"author": "llama", "source": url},
-        )
-        for i, url in enumerate(urls)
-    ]
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=actual_vector_db_id,
-        chunk_size_in_tokens=512,
-    )
-
-    response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
-        vector_db_ids=[actual_vector_db_id],
-        content="What is the name of the method used for fine-tuning?",
-    )
-    assert_valid_text_response(response_with_metadata)
-    assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
-
-    response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
-        vector_db_ids=[actual_vector_db_id],
-        content="What is the name of the method used for fine-tuning?",
-        query_config={
-            "include_metadata_in_content": True,
-            "chunk_template": "Result {index}\nContent: {chunk.content}\n",
-        },
-    )
-    assert_valid_text_response(response_without_metadata)
-    assert not any("metadata:" in chunk.text.lower() for chunk in response_without_metadata.content)
-
-    with pytest.raises((ValueError, BadRequestError)):
-        client_with_empty_registry.tool_runtime.rag_tool.query(
-            vector_db_ids=[actual_vector_db_id],
-            content="What is the name of the method used for fine-tuning?",
-            query_config={
-                "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
-            },
-        )
-
-
-def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
-    vector_db_id = "test_query_generation_db"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    actual_vector_db_id = available_vector_dbs[0]
-
-    documents = [
-        Document(
-            document_id="ai-doc",
-            content="Artificial intelligence and machine learning are transforming technology.",
-            metadata={"category": "AI"},
-        ),
-        Document(
-            document_id="banana-doc",
-            content="Don't bring a banana to a knife fight.",
-            metadata={"category": "wisdom"},
-        ),
-    ]
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=actual_vector_db_id,
-        chunk_size_in_tokens=256,
-    )
-
-    response = client_with_empty_registry.tool_runtime.rag_tool.query(
-        vector_db_ids=[actual_vector_db_id],
-        content="Tell me about AI",
-    )
-
-    assert_valid_text_response(response)
-    content_text = " ".join([chunk.text for chunk in response.content]).lower()
-    assert "artificial intelligence" in content_text or "machine learning" in content_text
-
-
-def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
-    vector_db_id = "test_pdf_data_url_db"
-
-    client_with_empty_registry.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model=embedding_model_id,
-        embedding_dimension=embedding_dimension,
-    )
-
-    available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
-    actual_vector_db_id = available_vector_dbs[0]
-
-    sample_pdf = b"%PDF-1.3\n3 0 obj\n<>\nendobj\n4 0 obj\n<>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<>\nendobj\n5 0 obj\n<>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
-
-    import base64
-
-    pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
-    pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
-
-    documents = [
-        Document(
-            document_id="test-pdf-data-url",
-            content=pdf_data_url,
-            metadata={"type": "pdf", "source": "data_url"},
-        ),
-    ]
-
-    client_with_empty_registry.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=actual_vector_db_id,
-        chunk_size_in_tokens=256,
-    )
-
-    files_list = client_with_empty_registry.files.list()
-    assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
-
-    pdf_file = None
-    for file in files_list.data:
-        if file.filename and "test-pdf-data-url" in file.filename:
-            pdf_file = file
-            break
-
-    assert pdf_file is not None, "PDF file should be found in Files API"
-    assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
-
-    file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
-    assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
-
-    vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
-        vector_store_id=actual_vector_db_id
-    )
-    assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
-
-    response = client_with_empty_registry.tool_runtime.rag_tool.query(
-        vector_db_ids=[actual_vector_db_id],
-        content="sample title",
-    )
-
-    assert_valid_text_response(response)
-    content_text = " ".join([chunk.text for chunk in response.content]).lower()
-    assert "sample title" in content_text or "title" in content_text
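The unit tests deleted below pinned down how VectorDBsRoutingTable mapped legacy vector_db identifiers onto OpenAI-style vector stores: a caller-supplied vector_db_id survives only as the store's display name, while the canonical identifier becomes a generated "vs_" ID. A standalone sketch of that contract (illustrative only, not the routing-table implementation; it mirrors the assertions in test_vector_db_id_becomes_vector_store_name below):

    # Illustrative only: the identifier contract asserted by the deleted
    # test_vector_db_id_becomes_vector_store_name.
    import uuid

    def register(user_provided_id: str) -> dict:
        # Keep the user's ID as the display name; mint a canonical
        # "vs_" identifier for routing.
        return {"id": f"vs_{uuid.uuid4()}", "name": user_provided_id}

    store = register("my-custom-vector-db")
    assert store["name"] == "my-custom-vector-db"
    assert store["id"].startswith("vs_") and store["id"] != "my-custom-vector-db"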
diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py
deleted file mode 100644
index 3444f64c2..000000000
--- a/tests/unit/distribution/routing_tables/test_vector_dbs.py
+++ /dev/null
@@ -1,381 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-# Unit tests for the routing tables vector_dbs
-
-import time
-import uuid
-from unittest.mock import AsyncMock
-
-import pytest
-
-from llama_stack.apis.datatypes import Api
-from llama_stack.apis.models import ModelType
-from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
-from llama_stack.apis.vector_io.vector_io import (
-    VectorStoreContent,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileCounts,
-    VectorStoreFileDeleteResponse,
-    VectorStoreFileObject,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
-)
-from llama_stack.core.access_control.datatypes import AccessRule, Scope
-from llama_stack.core.datatypes import User
-from llama_stack.core.request_headers import request_provider_data_context
-from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
-from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
-
-
-class VectorDBImpl(Impl):
-    def __init__(self):
-        super().__init__(Api.vector_io)
-        self.vector_stores = {}
-
-    async def register_vector_db(self, vector_db: VectorDB):
-        return vector_db
-
-    async def unregister_vector_db(self, vector_db_id: str):
-        return vector_db_id
-
-    async def openai_retrieve_vector_store(self, vector_store_id):
-        return VectorStoreObject(
-            id=vector_store_id,
-            name="Test Store",
-            created_at=int(time.time()),
-            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
-        )
-
-    async def openai_update_vector_store(self, vector_store_id, **kwargs):
-        return VectorStoreObject(
-            id=vector_store_id,
-            name="Updated Store",
-            created_at=int(time.time()),
-            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
-        )
-
-    async def openai_delete_vector_store(self, vector_store_id):
-        return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
-
-    async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
-        return VectorStoreSearchResponsePage(
-            object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
-        )
-
-    async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
-        return VectorStoreFileObject(
-            id=file_id,
-            status="completed",
-            chunking_strategy={"type": "auto"},
-            created_at=int(time.time()),
-            vector_store_id=vector_store_id,
-        )
-
-    async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
-        return [
-            VectorStoreFileObject(
-                id="1",
-                status="completed",
-                chunking_strategy={"type": "auto"},
-                created_at=int(time.time()),
-                vector_store_id=vector_store_id,
-            )
-        ]
-
-    async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
-        return VectorStoreFileObject(
-            id=file_id,
-            status="completed",
-            chunking_strategy={"type": "auto"},
-            created_at=int(time.time()),
-            vector_store_id=vector_store_id,
-        )
-
-    async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
-        return VectorStoreFileContentsResponse(
-            file_id=file_id,
-            filename="Sample File name",
-            attributes={"key": "value"},
-            content=[VectorStoreContent(type="text", text="Sample content")],
-        )
-
-    async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
-        return VectorStoreFileObject(
-            id=file_id,
-            status="completed",
-            chunking_strategy={"type": "auto"},
-            created_at=int(time.time()),
-            vector_store_id=vector_store_id,
-        )
-
-    async def openai_delete_vector_store_file(self, vector_store_id, file_id):
-        return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
-
-    async def openai_create_vector_store(
-        self,
-        name=None,
-        embedding_model=None,
-        embedding_dimension=None,
-        provider_id=None,
-        provider_vector_db_id=None,
-        **kwargs,
-    ):
-        vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
-        vector_store = VectorStoreObject(
-            id=vector_store_id,
-            name=name or vector_store_id,
-            created_at=int(time.time()),
-            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
-        )
-        self.vector_stores[vector_store_id] = vector_store
-        return vector_store
-
-    async def openai_list_vector_stores(self, **kwargs):
-        from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
-
-        return VectorStoreListResponse(
-            data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
-        )
-
-
-async def test_vectordbs_routing_table(cached_disk_dist_registry):
-    n = 10
-    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
-    await table.initialize()
-
-    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
-    await m_table.initialize()
-    await m_table.register_model(
-        model_id="test-model",
-        provider_id="test_provider",
-        metadata={"embedding_dimension": 128},
-        model_type=ModelType.embedding,
-    )
-
-    # Register multiple vector databases and verify listing
-    vdb_dict = {}
-    for i in range(n):
-        vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
-
-    vector_dbs = await table.list_vector_dbs()
-
-    assert len(vector_dbs.data) == len(vdb_dict)
-    vector_db_ids = {v.identifier for v in vector_dbs.data}
-    for k in vdb_dict:
-        assert vdb_dict[k].identifier in vector_db_ids
-    for k in vdb_dict:
-        await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
-
-    vector_dbs = await table.list_vector_dbs()
-    assert len(vector_dbs.data) == 0
-
-
-async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
-    n = 10
-    impl = VectorDBImpl()
-    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
-    await table.initialize()
-
-    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
-    await m_table.initialize()
-    await m_table.register_model(
-        model_id="test-model",
-        provider_id="test_provider",
-        metadata={"embedding_dimension": 128},
-        model_type=ModelType.embedding,
-    )
-
-    vdb_dict = {}
-    for i in range(n):
-        vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
-
-    vector_dbs = await table.list_vector_dbs()
-    vector_db_ids = {v.identifier for v in vector_dbs.data}
-
-    vector_stores = await impl.openai_list_vector_stores()
-    vector_store_ids = {v.id for v in vector_stores.data}
-
-    assert vector_db_ids == vector_store_ids, (
-        f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
-    )
-
-    for vector_store in vector_stores.data:
-        vector_db = await table.get_vector_db(vector_store.id)
-        assert vector_store.name == vector_db.vector_db_name, (
-            f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
-        )
-
-    for vector_db_id in vector_db_ids:
-        await table.unregister_vector_db(vector_db_id)
-
-    assert len((await table.list_vector_dbs()).data) == 0
-
-
-async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
-    impl = VectorDBImpl()
-    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
-    await table.initialize()
-
-    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
-    await m_table.initialize()
-    await m_table.register_model(
-        model_id="test-model",
-        provider_id="test_provider",
-        metadata={"embedding_dimension": 128},
-        model_type=ModelType.embedding,
-    )
-
-    user_provided_id = "my-custom-vector-db"
-    await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
-
-    vector_stores = await impl.openai_list_vector_stores()
-    assert len(vector_stores.data) == 1
-
-    vector_store = vector_stores.data[0]
-
-    assert vector_store.name == user_provided_id
-
-    assert vector_store.id.startswith("vs_")
-    assert vector_store.id != user_provided_id
-
-    vector_dbs = await table.list_vector_dbs()
-    assert len(vector_dbs.data) == 1
-    assert vector_dbs.data[0].identifier == vector_store.id
-
-    await table.unregister_vector_db(vector_store.id)
-
-
-async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
-    impl = VectorDBImpl()
-    impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
-    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
-    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
-    authorized_table = "vs1"
-    authorized_team = "team1"
-    unauthorized_team = "team2"
-
-    await m_table.initialize()
-    await m_table.register_model(
-        model_id="test-model",
-        provider_id="test_provider",
-        metadata={"embedding_dimension": 128},
-        model_type=ModelType.embedding,
-    )
-
-    authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
-    with request_provider_data_context({}, authorized_user):
-        registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
-        authorized_table = registered_vdb.identifier  # Use the actual generated ID
-
-    # Authorized reader
-    with request_provider_data_context({}, authorized_user):
-        res = await table.openai_retrieve_vector_store(authorized_table)
-        assert res == "OK"
-
-    # Authorized updater
-    impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
-    with request_provider_data_context({}, authorized_user):
-        res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
-        assert res == "UPDATED"
-
-    # Unauthorized reader
-    unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
-    with request_provider_data_context({}, unauthorized_user):
-        with pytest.raises(ValueError):
-            await table.openai_retrieve_vector_store(authorized_table)
-
-    # Unauthorized updater
-    with request_provider_data_context({}, unauthorized_user):
-        with pytest.raises(ValueError):
-            await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
-
-    # Authorized deleter
-    impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
-    with request_provider_data_context({}, authorized_user):
-        res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
-        assert res == "DELETED"
-
-    # Unauthorized deleter
-    with request_provider_data_context({}, unauthorized_user):
-        with pytest.raises(ValueError):
-            await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
-
-
-async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
-    impl = VectorDBImpl()
-
-    policy = [
-        AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
-        AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
-    ]
-
-    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
-    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
-
-    vector_db_id = "vs1"
-    file_id = "file-1"
-
-    admin_user = User(principal="admin", attributes={"roles": ["admin"]})
-    read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
-    no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
-
-    await m_table.initialize()
-    await m_table.register_model(
-        model_id="test-model",
-        provider_id="test_provider",
-        metadata={"embedding_dimension": 128},
-        model_type=ModelType.embedding,
-    )
-
-    with request_provider_data_context({}, admin_user):
-        registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
-        vector_db_id = registered_vdb.identifier  # Use the actual generated ID
-
-    read_methods = [
-        (table.openai_retrieve_vector_store, (vector_db_id,), {}),
-        (table.openai_search_vector_store, (vector_db_id, "query"), {}),
-        (table.openai_list_files_in_vector_store, (vector_db_id,), {}),
-        (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
-        (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
-    ]
-    update_methods = [
-        (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
-        (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
-        (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
-    ]
-    delete_methods = [
-        (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
-        (table.openai_delete_vector_store, (vector_db_id,), {}),
-    ]
-
-    for user in [admin_user, read_only_user]:
-        with request_provider_data_context({}, user):
-            for method, args, kwargs in read_methods:
-                result = await method(*args, **kwargs)
-                assert result is not None, f"Read operation failed with user {user.principal}"
-
-    with request_provider_data_context({}, no_access_user):
-        for method, args, kwargs in read_methods:
-            with pytest.raises(ValueError):
-                await method(*args, **kwargs)
-
-    with request_provider_data_context({}, admin_user):
-        for method, args, kwargs in update_methods:
-            result = await method(*args, **kwargs)
-            assert result is not None, "Update operation failed with admin user"
-
-    with request_provider_data_context({}, admin_user):
-        for method, args, kwargs in delete_methods:
-            result = await method(*args, **kwargs)
-            assert result is not None, "Delete operation failed with admin user"
-
-    for user in [read_only_user, no_access_user]:
-        with request_provider_data_context({}, user):
-            for method, args, kwargs in delete_methods:
-                with pytest.raises(ValueError):
-                    await method(*args, **kwargs)
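The two access-control tests deleted above follow the same pattern: each request carries a User whose roles are evaluated against AccessRule policies, and unauthorized calls surface as ValueError. A minimal sketch of how a caller's identity is attached, using the same User and request_provider_data_context helpers the deleted file imports (illustrative only):

    # Illustrative only; these are the same helpers imported by the
    # deleted test file above.
    from llama_stack.core.datatypes import User
    from llama_stack.core.request_headers import request_provider_data_context

    reader = User(principal="bob", attributes={"roles": ["reader"]})
    with request_provider_data_context({}, reader):
        # Calls made here are attributed to "bob"; under the policy above,
        # reads succeed while update/delete operations raise ValueError.
        ...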