Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-23 05:12:26 +00:00)

Merge branch 'main' into vectordb_name

Commit bd8c1cc071: 52 changed files with 1363 additions and 921 deletions
.github/ISSUE_TEMPLATE/tech-debt.yml (new file, 30 lines)

@@ -0,0 +1,30 @@
+name: 🔧 Tech Debt
+description: Something that is functional but should be improved or optimized
+labels: ["tech-debt"]
+body:
+  - type: textarea
+    id: tech-debt-explanation
+    attributes:
+      label: 🤔 What is the technical debt you think should be addressed?
+      description: >
+        A clear and concise description of _what_ needs to be addressed - ensure you are describing
+        what constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug
+        or feature request.
+    validations:
+      required: true
+
+  - type: textarea
+    id: tech-debt-motivation
+    attributes:
+      label: 💡 What is the benefit of addressing this technical debt?
+      description: >
+        A clear and concise description of _why_ this work is needed.
+    validations:
+      required: true
+
+  - type: textarea
+    id: other-thoughts
+    attributes:
+      label: Other thoughts
+      description: >
+        Any thoughts about how this may result in complexity in the codebase, or other trade-offs.
.github/workflows/integration-auth-tests.yml (2 lines changed)

@@ -35,7 +35,7 @@ jobs:

       - name: Install minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
-        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
+        uses: medyagh/setup-minikube@e3c7f79eb1e997eabccc536a6cf318a2b0fe19d9 # v0.0.20

       - name: Start minikube
         if: ${{ matrix.auth-provider == 'oauth2_token' }}
.github/workflows/integration-tests.yml (2 lines changed)

@@ -89,7 +89,7 @@ jobs:
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="ollama/llama3.2:3b-instruct-fp16" \
             --embedding-model=all-MiniLM-L6-v2 \
-            --safety-shield=ollama \
+            --safety-shield=$SAFETY_MODEL \
             --color=yes \
             --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
docs/_static/llama-stack-spec.html (13 lines changed)

@@ -14795,7 +14795,8 @@
           "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
         },
         "mode": {
-          "type": "string",
+          "$ref": "#/components/schemas/RAGSearchMode",
+          "default": "vector",
           "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
         },
         "ranker": {

@@ -14830,6 +14831,16 @@
         }
       }
     },
+    "RAGSearchMode": {
+      "type": "string",
+      "enum": [
+        "vector",
+        "keyword",
+        "hybrid"
+      ],
+      "title": "RAGSearchMode",
+      "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
+    },
     "RRFRanker": {
       "type": "object",
       "properties": {
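(Aside: the `chunk_template` placeholders described above are plain `str.format` placeholders, so they can be sanity-checked in isolation. A minimal sketch, using a `SimpleNamespace` stand-in for the real `Chunk` model, which is not defined in this spec excerpt:)

```python
# Illustrative only: the SimpleNamespace chunk is a hypothetical stand-in for
# the real Chunk model from the vector-io API.
from types import SimpleNamespace

chunk = SimpleNamespace(
    content="Paris is the capital of France.",
    metadata={"source": "geo.txt"},
)
template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
# str.format resolves {chunk.content} via attribute lookup, and {index} and
# {metadata} as ordinary keyword placeholders, exactly as described above.
print(template.format(index=1, chunk=chunk, metadata=chunk.metadata))
```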
docs/_static/llama-stack-spec.yaml (14 lines changed)

@@ -10344,7 +10344,8 @@ components:
         content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
         {chunk.content}\nMetadata: {metadata}\n"
       mode:
-        type: string
+        $ref: '#/components/schemas/RAGSearchMode'
+        default: vector
         description: >-
           Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
           "vector".

@@ -10371,6 +10372,17 @@ components:
       mapping:
         default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
         llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
+    RAGSearchMode:
+      type: string
+      enum:
+        - vector
+        - keyword
+        - hybrid
+      title: RAGSearchMode
+      description: >-
+        Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
+        for semantic matching - KEYWORD: Uses keyword-based search for exact matching
+        - HYBRID: Combines both vector and keyword search for better results
     RRFRanker:
       type: object
       properties:
@@ -145,6 +145,10 @@ $ llama stack build --template starter
 ...
 You can now edit ~/.llama/distributions/llamastack-starter/starter-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-starter/starter-run.yaml`
 ```

+```{tip}
+The generated `run.yaml` file is a starting point for your configuration. For comprehensive guidance on customizing it for your specific needs, infrastructure, and deployment scenarios, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
+
 :::
 :::{tab-item} Building from Scratch
@@ -2,6 +2,10 @@

 The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:

+```{note}
+The default `run.yaml` files generated by templates are starting points for your configuration. For guidance on customizing these files for your specific needs, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
+
 ```{dropdown} 👋 Click here for a Sample Configuration File

 ```yaml
docs/source/distributions/customizing_run_yaml.md (new file, 40 lines)

@@ -0,0 +1,40 @@
+# Customizing run.yaml Files
+
+The `run.yaml` files generated by Llama Stack templates are **starting points** designed to be customized for your specific needs. They are not meant to be used as-is in production environments.
+
+## Key Points
+
+- **Templates are starting points**: Generated `run.yaml` files contain defaults for development/testing
+- **Customization expected**: Update URLs, credentials, models, and settings for your environment
+- **Version control separately**: Keep customized configs in your own repository
+- **Environment-specific**: Create different configurations for dev, staging, production
+
+## What You Can Customize
+
+You can customize:
+- **Provider endpoints**: Change `http://localhost:8000` to your actual servers
+- **Swap providers**: Replace default providers (e.g., swap Tavily with Brave for search)
+- **Storage paths**: Move from `/tmp/` to production directories
+- **Authentication**: Add API keys, SSL, timeouts
+- **Models**: Different model sizes for dev vs prod
+- **Database settings**: Switch from SQLite to PostgreSQL
+- **Tool configurations**: Add custom tools and integrations
+
+## Best Practices
+
+- Use environment variables for secrets and environment-specific values
+- Create separate `run.yaml` files for different environments (dev, staging, prod)
+- Document your changes with comments
+- Test configurations before deployment
+- Keep your customized configs in version control
+
+Example structure:
+```
+your-project/
+├── configs/
+│   ├── dev-run.yaml
+│   ├── prod-run.yaml
+└── README.md
+```
+
+The goal is to take the generated template and adapt it to your specific infrastructure and operational needs.
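Since the page above recommends environment variables for secrets and environment-specific values, here is a minimal sketch of how the `${env.VAR:=default}` placeholders used throughout these sample configurations could be resolved. This only illustrates the placeholder syntax; it is not the actual llama-stack resolver.

```python
# Sketch of ${env.VAR} / ${env.VAR:=default} resolution, assuming the syntax
# shown in the sample run.yaml files in this change set; the real llama-stack
# implementation may differ.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>\w+)(?::=(?P<default>[^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        resolved = os.environ.get(match.group("name"), match.group("default"))
        if resolved is None:
            raise ValueError(f"environment variable {match.group('name')} is not set")
        return resolved

    return _ENV_PATTERN.sub(_sub, value)

# With PGVECTOR_HOST and PGVECTOR_PORT unset, the defaults kick in:
print(resolve_env_placeholders("${env.PGVECTOR_HOST:=localhost}:${env.PGVECTOR_PORT:=5432}"))
# -> localhost:5432
```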
@@ -9,6 +9,7 @@ This section provides an overview of the distributions available in Llama Stack.

 importing_as_library
 configuration
+customizing_run_yaml
 list_of_distributions
 kubernetes_deployment
 building_distro
@@ -54,7 +54,7 @@ Llama Stack is a server that exposes multiple APIs, you connect with it using th
 You can use Python to build and run the Llama Stack server, which is useful for testing and development.

 Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
-which defines the providers and their settings.
+which defines the providers and their settings. The generated configuration serves as a starting point that you can [customize for your specific needs](../distributions/customizing_run_yaml.md).
 Now let's build and run the Llama Stack config for Ollama.
 We use `starter` as the template. By default all providers are disabled; this requires enabling Ollama by passing environment variables.
@@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template
 You can use a container image to run the Llama Stack server. We provide several container images for the server
 component that work with different inference providers out of the box. For this guide, we will use
 `llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the
-configurations, please check out [this guide](../references/index.md).
+configurations, please check out [this guide](../distributions/building_distro.md).
 First, let's set up some environment variables and create a local directory to mount into the container's file system.
 ```bash
 export INFERENCE_MODEL="llama3.2:3b"
@@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 | `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
 | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
 | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
 | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |

 > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.

@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 ```yaml
 uri: ${env.MILVUS_ENDPOINT}
 token: ${env.MILVUS_TOKEN}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
 ```
@@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details.
 | `db` | `str \| None` | No | postgres | |
 | `user` | `str \| None` | No | postgres | |
 | `password` | `str \| None` | No | mysecretpassword | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |

 ## Sample Configuration

@@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432}
 db: ${env.PGVECTOR_DB}
 user: ${env.PGVECTOR_USER}
 password: ${env.PGVECTOR_PASSWORD}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
 ```
@@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details.
 ## Sample Configuration

 ```yaml
-{}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
 ```
@@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum):
     custom = "custom"


+@json_schema_type
+class RAGSearchMode(Enum):
+    """
+    Search modes for RAG query retrieval:
+    - VECTOR: Uses vector similarity search for semantic matching
+    - KEYWORD: Uses keyword-based search for exact matching
+    - HYBRID: Combines both vector and keyword search for better results
+    """
+
+    VECTOR = "vector"
+    KEYWORD = "keyword"
+    HYBRID = "hybrid"
+
+
 @json_schema_type
 class DefaultRAGQueryGeneratorConfig(BaseModel):
     type: Literal["default"] = "default"

@@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel):
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
     chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
-    mode: str | None = None
+    mode: RAGSearchMode | None = RAGSearchMode.VECTOR
     ranker: Ranker | None = Field(default=None)  # Only used for hybrid mode

     @field_validator("chunk_template")
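The net effect of the `mode` change, as a minimal self-contained sketch (the two classes are re-declared here so the snippet runs on its own; the real definitions are the ones in the diff above): `mode` is now a typed enum defaulting to vector, and plain strings are still accepted and coerced by value.

```python
# Minimal sketch of the new mode field; re-declared locally for illustration.
from enum import Enum

from pydantic import BaseModel

class RAGSearchMode(Enum):
    VECTOR = "vector"
    KEYWORD = "keyword"
    HYBRID = "hybrid"

class RAGQueryConfig(BaseModel):
    mode: RAGSearchMode | None = RAGSearchMode.VECTOR

print(RAGQueryConfig().mode)                # RAGSearchMode.VECTOR (new default)
print(RAGQueryConfig(mode="keyword").mode)  # pydantic coerces the string by value
```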
@@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             )
             self.cache[vector_db.identifier] = index

-        # Load existing OpenAI vector stores using the mixin method
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
         # Cleanup if needed

@@ -261,42 +261,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):

         return await index.query_chunks(query, params)

-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
-        """Save vector store file metadata to kvstore."""
+        """Save vector store file data to kvstore."""
         assert self.kvstore is not None
         key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
         await self.kvstore.set(key=key, value=json.dumps(file_info))

@@ -324,7 +292,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         await self.kvstore.set(key=key, value=json.dumps(file_info))

     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-        """Delete vector store file metadata from kvstore."""
+        """Delete vector store data from kvstore."""
         assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
-        await self.kvstore.delete(key)
+        keys_to_delete = [
+            f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
+            f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
+        ]
+        for key in keys_to_delete:
+            try:
+                await self.kvstore.delete(key)
+            except Exception as e:
+                logger.warning(f"Failed to delete key {key}: {e}")
+                continue
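The kvstore CRUD helpers deleted above are presumably consolidated behind `OpenAIVectorStoreMixin.initialize_openai_vector_stores()`. As a hedged sketch of what that call replaces, mirroring the deleted `_load_openai_vector_stores` body (the actual mixin implementation may differ):

```python
# Hedged sketch only: assumes the mixin mirrors the provider code removed in
# this diff; kvstore is the same KVStore each adapter already configures.
import json

OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:v3::"  # provider-specific in practice

class OpenAIVectorStoreMixinSketch:
    async def initialize_openai_vector_stores(self) -> None:
        assert self.kvstore is not None
        start_key = OPENAI_VECTOR_STORES_PREFIX
        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"  # range scan over the key prefix
        stored = await self.kvstore.values_in_range(start_key, end_key)
        self.openai_vector_stores = {json.loads(s)["id"]: json.loads(s) for s in stored}
```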
@@ -452,8 +452,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             )
             self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

-        # load any existing OpenAI vector stores
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
         # nothing to do since we don't maintain a persistent connection

@@ -501,41 +501,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]

-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )

+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",

@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],

@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES
@@ -8,7 +8,7 @@ from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type


@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
     uri: str = Field(description="The URI of the Milvus server")
     token: str | None = Field(description="The token of the Milvus server")
     consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
-    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend")

     # This configuration allows additional fields to be passed through to the underlying Milvus client.
     # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.

@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):

     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
+        return {
+            "uri": "${env.MILVUS_ENDPOINT}",
+            "token": "${env.MILVUS_TOKEN}",
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="milvus_remote_registry.db",
+            ),
+        }
@@ -12,7 +12,7 @@ import re
 from typing import Any

 from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import DataType, Function, FunctionType, MilvusClient

 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent

@@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )

         if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+            logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
+            # Create schema for vector search
+            schema = self.client.create_schema()
+            schema.add_field(
+                field_name="chunk_id",
+                datatype=DataType.VARCHAR,
+                is_primary=True,
+                max_length=100,
+            )
+            schema.add_field(
+                field_name="content",
+                datatype=DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=True,  # Enable text analysis for BM25
+            )
+            schema.add_field(
+                field_name="vector",
+                datatype=DataType.FLOAT_VECTOR,
+                dim=len(embeddings[0]),
+            )
+            schema.add_field(
+                field_name="chunk_content",
+                datatype=DataType.JSON,
+            )
+            # Add sparse vector field for BM25 (required by the function)
+            schema.add_field(
+                field_name="sparse",
+                datatype=DataType.SPARSE_FLOAT_VECTOR,
+            )
+
+            # Create indexes
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="vector",
+                index_type="FLAT",
+                metric_type="COSINE",
+            )
+            # Add index for sparse field (required by BM25 function)
+            index_params.add_index(
+                field_name="sparse",
+                index_type="SPARSE_INVERTED_INDEX",
+                metric_type="BM25",
+            )
+
+            # Add BM25 function for full-text search
+            bm25_function = Function(
+                name="text_bm25_emb",
+                input_field_names=["content"],
+                output_field_names=["sparse"],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_function)
+
             await asyncio.to_thread(
                 self.client.create_collection,
                 self.collection_name,
-                dimension=len(embeddings[0]),
-                auto_id=True,
+                schema=schema,
+                index_params=index_params,
                 consistency_level=self.consistency_level,
             )

@@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
             data.append(
                 {
                     "chunk_id": chunk.chunk_id,
+                    "content": chunk.content,
                     "vector": embedding,
                     "chunk_content": chunk.model_dump(),
+                    # sparse field will be handled by BM25 function automatically
                 }
             )
         try:

@@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
             self.client.search,
             collection_name=self.collection_name,
             data=[embedding],
+            anns_field="vector",
             limit=k,
             output_fields=["*"],
             search_params={"params": {"radius": score_threshold}},
@@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Milvus")
+        """
+        Perform BM25-based keyword search using Milvus's built-in full-text search.
+        """
+        try:
+            # Use Milvus's built-in BM25 search
+            search_res = await asyncio.to_thread(
+                self.client.search,
+                collection_name=self.collection_name,
+                data=[query_string],  # Raw text query
+                anns_field="sparse",  # Use sparse field for BM25
+                output_fields=["chunk_content"],  # Output the chunk content
+                limit=k,
+                search_params={
+                    "params": {
+                        "drop_ratio_search": 0.2,  # Ignore low-importance terms
+                    }
+                },
+            )
+
+            chunks = []
+            scores = []
+            for res in search_res[0]:
+                chunk = Chunk(**res["entity"]["chunk_content"])
+                chunks.append(chunk)
+                scores.append(res["distance"])  # BM25 score from Milvus
+
+            # Filter by score threshold
+            filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+            filtered_scores = [score for score in scores if score >= score_threshold]
+
+            return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
+
+        except Exception as e:
+            logger.error(f"Error performing BM25 search: {e}")
+            # Fall back to simple text search
+            return await self._fallback_keyword_search(query_string, k, score_threshold)
+
+    async def _fallback_keyword_search(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Fall back to simple text search when BM25 search is not available.
+        """
+        # Simple text search using content field
+        search_res = await asyncio.to_thread(
+            self.client.query,
+            collection_name=self.collection_name,
+            filter='content like "%{content}%"',
+            filter_params={"content": query_string},
+            output_fields=["*"],
+            limit=k,
+        )
+        chunks = [Chunk(**res["chunk_content"]) for res in search_res]
+        scores = [1.0] * len(chunks)  # Simple binary score for text search
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def query_hybrid(
         self,
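For context, a hypothetical call shape for exercising the new keyword mode through the adapter; the `{"mode": "keyword"}` params key mirrors the check added to `query_chunks` in the next hunk, and the adapter wiring is assumed.

```python
# Hypothetical usage sketch: `adapter` is a configured remote
# MilvusVectorIOAdapter (keyword mode is rejected on Milvus-Lite below).
async def keyword_demo(adapter, vector_db_id: str) -> None:
    response = await adapter.query_chunks(
        vector_db_id=vector_db_id,
        query="distributed consensus",
        params={"mode": "keyword"},
    )
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(f"{score:.3f}  {chunk.content}")
```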
@@ -179,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             uri = os.path.expanduser(self.config.db_path)
             self.client = MilvusClient(uri=uri)

-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
         self.client.close()

@@ -246,38 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")

+        if params and params.get("mode") == "keyword":
+            # Check if this is inline Milvus (Milvus-Lite)
+            if hasattr(self.config, "db_path"):
+                raise NotImplementedError(
+                    "Keyword search is not supported in Milvus-Lite. "
+                    "Please use a remote Milvus server for keyword search functionality."
+                )
+
         return await index.query_chunks(query, params)

-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored = await self.kvstore.values_in_range(start_key, end_key)
-        return {json.loads(s)["id"]: json.loads(s) for s in stored}
-
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
@@ -8,6 +8,10 @@ from typing import Any

 from pydantic import BaseModel, Field

+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 from llama_stack.schema_utils import json_schema_type


@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
     db: str | None = Field(default="postgres")
     user: str | None = Field(default="postgres")
     password: str | None = Field(default="mysecretpassword")
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)

     @classmethod
     def sample_run_config(
         cls,
+        __distro_dir__: str,
         host: str = "${env.PGVECTOR_HOST:=localhost}",
         port: int = "${env.PGVECTOR_PORT:=5432}",
         db: str = "${env.PGVECTOR_DB}",

@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
         password: str = "${env.PGVECTOR_PASSWORD}",
         **kwargs: Any,
     ) -> dict[str, Any]:
-        return {"host": host, "port": port, "db": db, "user": user, "password": password}
+        return {
+            "host": host,
+            "port": port,
+            "db": db,
+            "user": user,
+            "password": password,
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="pgvector_registry.db",
+            ),
+        }
@@ -13,24 +13,18 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter

+from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
-    SearchRankingOptions,
     VectorIO,
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,

@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig

 log = logging.getLogger(__name__)

+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
+

 def check_extension_version(cur):
     cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")

@@ -69,7 +70,7 @@ def load_models(cur, cls):


 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+    def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
         self.conn = conn
         with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             # Sanitize the table name by replacing hyphens with underscores

@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
             # when created with patterns like "test-vector-db-{uuid4()}"
             sanitized_identifier = vector_db.identifier.replace("-", "_")
             self.table_name = f"vector_store_{sanitized_identifier}"
+            self.kvstore = kvstore

             cur.execute(
                 f"""

@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


-class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
+class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+    def __init__(
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None = None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.conn = None
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_store: dict[str, dict[str, Any]] = {}
+        self.metadatadata_collection_name = "openai_vector_stores_metadata"

     async def initialize(self) -> None:
         log.info(f"Initializing PGVector memory adapter with config: {self.config}")
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        await self.initialize_openai_vector_stores()

         try:
             self.conn = psycopg2.connect(
                 host=self.config.host,

@@ -201,14 +216,31 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         log.info("Connection to PGVector database server closed")

     async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # Persist vector DB metadata in the KV store
+        assert self.kvstore is not None
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())
+
+        # Upsert model metadata in Postgres
         upsert_models(self.conn, [(vector_db.identifier, vector_db)])

-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
-        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+        # Create and cache the PGVector index table for the vector DB
+        index = VectorDBWithIndex(
+            vector_db,
+            index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_db.identifier] = index

     async def unregister_vector_db(self, vector_db_id: str) -> None:
-        await self.cache[vector_db_id].index.delete()
-        del self.cache[vector_db_id]
+        # Remove provider index and cache
+        if vector_db_id in self.cache:
+            await self.cache[vector_db_id].index.delete()
+            del self.cache[vector_db_id]
+
+        # Delete vector DB metadata from KV store
+        assert self.kvstore is not None
+        await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")

     async def insert_chunks(
         self,

@@ -237,106 +269,20 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]

-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    # OpenAI Vector Stores File operations are not supported in PGVector
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
@@ -6,15 +6,26 @@

 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)


 class WeaviateRequestProviderData(BaseModel):
     weaviate_api_key: str
     weaviate_cluster_url: str
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)


 class WeaviateVectorIOConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
-        return {}
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="weaviate_registry.db",
+            ),
+        }
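The config change above threads __distro_dir__ into a SQLite-backed KV store entry. A minimal usage sketch, assuming the usual module path for the Weaviate config; the distro path shown is illustrative, not fixed by this diff:

    # Hypothetical usage; import location and distro path are assumptions.
    from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig

    cfg = WeaviateVectorIOConfig.sample_run_config(__distro_dir__="~/.llama/distributions/starter")
    # Per the diff, this should yield a SQLite KV store config whose db file is
    # weaviate_registry.db under the given distribution directory.
    print(cfg["kvstore"])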
@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter

 from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.files.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig

 log = logging.getLogger(__name__)

+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
+

 class WeaviateIndex(EmbeddingIndex):
-    def __init__(self, client: weaviate.Client, collection_name: str):
+    def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = collection_name
+        self.kvstore = kvstore

     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
     NeedsRequestProviderData,
     VectorDBsProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.metadata_collection_name = "openai_vector_stores_metadata"

     def _get_client(self) -> weaviate.Client:
         provider_data = self.get_request_provider_data()
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
         return client

     async def initialize(self) -> None:
-        pass
+        """Set up KV store and load existing vector DBs and OpenAI vector stores."""
+        # Initialize KV store for metadata
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Load existing vector DB definitions
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored = await self.kvstore.values_in_range(start_key, end_key)
+        for raw in stored:
+            vector_db = VectorDB.model_validate_json(raw)
+            client = self._get_client()
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
+            self.cache[vector_db.identifier] = VectorDBWithIndex(
+                vector_db=vector_db,
+                index=idx,
+                inference_api=self.inference_api,
+            )
+
+        # Load OpenAI vector stores metadata into cache
+        await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
         for client in self.client_cache.values():
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
             raise ValueError(f"Vector DB {vector_db_id} not found")

         return await index.query_chunks(query, params)
+
+    # OpenAI Vector Stores File operations are not supported in Weaviate
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
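The new initialize() above recovers registered vector DBs with a KV prefix range scan: every key under VECTOR_DBS_PREFIX sorts between the prefix itself and the prefix followed by "\xff". A self-contained sketch of that pattern; the in-memory dict below stands in for the real KV store:

    # Illustrative only: shows the prefix-scan trick initialize() relies on.
    VERSION = "v3"
    VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

    store = {
        f"{VECTOR_DBS_PREFIX}db1": '{"identifier": "db1"}',
        f"{VECTOR_DBS_PREFIX}db2": '{"identifier": "db2"}',
        "unrelated:key": "ignored",
    }

    start_key = VECTOR_DBS_PREFIX
    end_key = f"{VECTOR_DBS_PREFIX}\xff"  # "\xff" sorts after any printable suffix
    values = [v for k, v in store.items() if start_key <= k < end_key]
    assert len(values) == 2  # only the weaviate vector_dbs entries match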
@@ -83,9 +83,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def get_llama_model(self, provider_model_id: str) -> str | None:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)

+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider (non-static check).
+
+        This is for subclassing purposes, so providers can check if a specific
+        model is currently available for use through dynamic means (e.g., API calls).
+
+        This method should NOT check statically configured model entries in
+        `self.alias_to_provider_id_map` - that is handled separately in register_model.
+
+        Default implementation returns False (no dynamic models available).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        return False
+
     async def register_model(self, model: Model) -> Model:
-        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
+        # Check if model is supported in static configuration
+        supported_model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        # If not found in static config, check if it's available dynamically from provider
+        if not supported_model_id:
+            if await self.check_model_availability(model.provider_resource_id):
+                supported_model_id = model.provider_resource_id
+            else:
+                # note: we cannot provide a complete list of supported models without
+                # getting a complete list from the provider, so we return "..."
+                all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
+                raise UnsupportedModelError(model.provider_resource_id, all_supported_models)

         provider_resource_id = self.get_provider_model_id(model.model_id)
         if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
@@ -114,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                 ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
             )

+        # Register the model alias, ensuring it maps to the correct provider model id
         self.alias_to_provider_id_map[model.model_id] = supported_model_id

         return model
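check_model_availability() is an explicit subclassing hook that register_model() now consults before rejecting an unknown id. A sketch of a provider override, assuming the usual module path; the client call shown is a hypothetical, not part of this diff:

    # Hypothetical subclass; self.client and its models.list() API are assumptions.
    from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

    class MyProviderAdapter(ModelRegistryHelper):
        async def check_model_availability(self, model: str) -> bool:
            # Ask the remote provider which models it currently serves.
            available = await self.client.models.list()  # assumed client method
            return model in {m.id for m in available}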
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import json
 import logging
 import mimetypes
 import time
@@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

 logger = logging.getLogger(__name__)
@@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    # KV store for persisting OpenAI vector store metadata
+    kvstore: KVStore | None

-    @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            stores[info["id"]] = info
+        return stores

-    @abstractmethod
     async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Update vector store metadata in persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
         """Delete vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_vector_stores.pop(store_id, None)

     @abstractmethod
     async def _save_openai_vector_store_file(
@@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
         """Unregister a vector database (provider-specific implementation)."""
         pass

+    async def initialize_openai_vector_stores(self) -> None:
+        """Load existing OpenAI vector stores into the in-memory cache."""
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
     @abstractmethod
     async def insert_chunks(
         self,
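With the kvstore-backed defaults above, each store's metadata lives as a JSON blob under a prefixed key, and openai_vector_stores is just an in-memory cache of those rows. A self-contained sketch of the round trip; the prefix value is illustrative and a plain dict stands in for the KV store:

    import json

    OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:v3::"  # example prefix

    kv: dict[str, str] = {}
    store_info = {"id": "vs_123", "name": "docs"}

    # _save_openai_vector_store: serialize under a prefixed key
    kv[f"{OPENAI_VECTOR_STORES_PREFIX}{store_info['id']}"] = json.dumps(store_info)

    # _load_openai_vector_stores: scan the prefix range and rebuild the cache
    stores: dict[str, dict] = {}
    for key, value in kv.items():
        if key.startswith(OPENAI_VECTOR_STORES_PREFIX):
            info = json.loads(value)
            stores[info["id"]] = info
    assert stores["vs_123"]["name"] == "docs"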
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
@@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
@@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
@@ -54,6 +54,9 @@ providers:
         db: ${env.PGVECTOR_DB:=}
         user: ${env.PGVECTOR_USER:=}
         password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -166,6 +166,9 @@ providers:
         db: ${env.PGVECTOR_DB:=}
         user: ${env.PGVECTOR_USER:=}
         password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
@@ -1171,24 +1174,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
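The shield entries above lean on the stack's ${env.VAR:=default} substitution: the variable's value when set, otherwise the default (here __disabled__ turns the shield off). A rough Python equivalent of that resolution rule; the helper name is ours, not the stack's:

    import os

    def resolve(var: str, default: str) -> str:
        # Approximates "${env.VAR:=default}": env value when set and non-empty, else default.
        value = os.environ.get(var, "")
        return value if value else default

    # e.g. with SAFETY_MODEL unset, the shield id resolves to "__disabled__"
    assert resolve("SAFETY_MODEL_UNSET_EXAMPLE", "__disabled__") == "__disabled__"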
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )


@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }

     return safety_model_entries_map.get(provider_type, [])


@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:


 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models


 def get_distribution_template() -> DistributionTemplate:
@@ -300,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
@@ -307,8 +249,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +301,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)

     return DistributionTemplate(
         name=name,

@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge

 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []

     # check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields


 class DefaultModel(BaseModel):
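get_shield_registry() threads the conflict flag from get_model_registry() through to shield naming: provider_shield_id is provider-prefixed only when model ids collided. A self-contained sketch of that core rule; the Entry dataclass below is a stand-in for ProviderModelEntry:

    from dataclasses import dataclass, field

    @dataclass
    class Entry:  # stand-in for ProviderModelEntry, with the two fields the helper reads
        provider_model_id: str
        aliases: list[str] = field(default_factory=list)

    ids_conflict_in_models = False  # as returned by get_model_registry
    provider_id = "${env.ENABLE_OLLAMA:=__disabled__}"
    entry = Entry("${env.SAFETY_MODEL:=__disabled__}")

    # get_shield_registry's core rule: prefix provider_shield_id only on conflicts
    provider_shield_id = (
        f"{provider_id}/{entry.provider_model_id}" if ids_conflict_in_models else entry.provider_model_id
    )
    assert provider_shield_id == "${env.SAFETY_MODEL:=__disabled__}"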
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
82 llama_stack/ui/app/logs/vector-stores/[id]/page.tsx Normal file
@@ -0,0 +1,82 @@
+"use client";
+
+import { useEffect, useState } from "react";
+import { useParams, useRouter } from "next/navigation";
+import { useAuthClient } from "@/hooks/use-auth-client";
+import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
+import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
+import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";
+
+export default function VectorStoreDetailPage() {
+  const params = useParams();
+  const id = params.id as string;
+  const client = useAuthClient();
+  const router = useRouter();
+
+  const [store, setStore] = useState<VectorStore | null>(null);
+  const [files, setFiles] = useState<VectorStoreFile[]>([]);
+  const [isLoadingStore, setIsLoadingStore] = useState(true);
+  const [isLoadingFiles, setIsLoadingFiles] = useState(true);
+  const [errorStore, setErrorStore] = useState<Error | null>(null);
+  const [errorFiles, setErrorFiles] = useState<Error | null>(null);
+
+  useEffect(() => {
+    if (!id) {
+      setErrorStore(new Error("Vector Store ID is missing."));
+      setIsLoadingStore(false);
+      return;
+    }
+    const fetchStore = async () => {
+      setIsLoadingStore(true);
+      setErrorStore(null);
+      try {
+        const response = await client.vectorStores.retrieve(id);
+        setStore(response as VectorStore);
+      } catch (err) {
+        setErrorStore(
+          err instanceof Error
+            ? err
+            : new Error("Failed to load vector store."),
+        );
+      } finally {
+        setIsLoadingStore(false);
+      }
+    };
+    fetchStore();
+  }, [id, client]);
+
+  useEffect(() => {
+    if (!id) {
+      setErrorFiles(new Error("Vector Store ID is missing."));
+      setIsLoadingFiles(false);
+      return;
+    }
+    const fetchFiles = async () => {
+      setIsLoadingFiles(true);
+      setErrorFiles(null);
+      try {
+        const result = await client.vectorStores.files.list(id as any);
+        setFiles((result as any).data);
+      } catch (err) {
+        setErrorFiles(
+          err instanceof Error ? err : new Error("Failed to load files."),
+        );
+      } finally {
+        setIsLoadingFiles(false);
+      }
+    };
+    fetchFiles();
+  }, [id]);
+
+  return (
+    <VectorStoreDetailView
+      store={store}
+      files={files}
+      isLoadingStore={isLoadingStore}
+      isLoadingFiles={isLoadingFiles}
+      errorStore={errorStore}
+      errorFiles={errorFiles}
+      id={id}
+    />
+  );
+}
16 llama_stack/ui/app/logs/vector-stores/layout.tsx Normal file
@@ -0,0 +1,16 @@
+"use client";
+
+import React from "react";
+import LogsLayout from "@/components/layout/logs-layout";
+
+export default function VectorStoresLayout({
+  children,
+}: {
+  children: React.ReactNode;
+}) {
+  return (
+    <LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
+      {children}
+    </LogsLayout>
+  );
+}
121 llama_stack/ui/app/logs/vector-stores/page.tsx Normal file
@@ -0,0 +1,121 @@
+"use client";
+
+import React from "react";
+import { useAuthClient } from "@/hooks/use-auth-client";
+import type {
+  ListVectorStoresResponse,
+  VectorStore,
+} from "llama-stack-client/resources/vector-stores/vector-stores";
+import { useRouter } from "next/navigation";
+import { usePagination } from "@/hooks/use-pagination";
+import {
+  Table,
+  TableBody,
+  TableCaption,
+  TableCell,
+  TableHead,
+  TableHeader,
+  TableRow,
+} from "@/components/ui/table";
+import { Skeleton } from "@/components/ui/skeleton";
+
+export default function VectorStoresPage() {
+  const client = useAuthClient();
+  const router = useRouter();
+  const {
+    data: stores,
+    status,
+    hasMore,
+    error,
+    loadMore,
+  } = usePagination<VectorStore>({
+    limit: 20,
+    order: "desc",
+    fetchFunction: async (client, params) => {
+      const response = await client.vectorStores.list({
+        after: params.after,
+        limit: params.limit,
+        order: params.order,
+      } as any);
+      return response as ListVectorStoresResponse;
+    },
+    errorMessagePrefix: "vector stores",
+  });
+
+  // Auto-load all pages for infinite scroll behavior (like Responses)
+  React.useEffect(() => {
+    if (status === "idle" && hasMore) {
+      loadMore();
+    }
+  }, [status, hasMore, loadMore]);
+
+  if (status === "loading") {
+    return (
+      <div className="space-y-2">
+        <Skeleton className="h-8 w-full" />
+        <Skeleton className="h-4 w-full" />
+        <Skeleton className="h-4 w-full" />
+      </div>
+    );
+  }
+
+  if (status === "error") {
+    return <div className="text-destructive">Error: {error?.message}</div>;
+  }
+
+  if (!stores || stores.length === 0) {
+    return <p>No vector stores found.</p>;
+  }
+
+  return (
+    <div className="overflow-auto flex-1 min-h-0">
+      <Table>
+        <TableHeader>
+          <TableRow>
+            <TableHead>ID</TableHead>
+            <TableHead>Name</TableHead>
+            <TableHead>Created</TableHead>
+            <TableHead>Completed</TableHead>
+            <TableHead>Cancelled</TableHead>
+            <TableHead>Failed</TableHead>
+            <TableHead>In Progress</TableHead>
+            <TableHead>Total</TableHead>
+            <TableHead>Usage Bytes</TableHead>
+            <TableHead>Provider ID</TableHead>
+            <TableHead>Provider Vector DB ID</TableHead>
+          </TableRow>
+        </TableHeader>
+        <TableBody>
+          {stores.map((store) => {
+            const fileCounts = store.file_counts;
+            const metadata = store.metadata || {};
+            const providerId = metadata.provider_id ?? "";
+            const providerDbId = metadata.provider_vector_db_id ?? "";
+
+            return (
+              <TableRow
+                key={store.id}
+                onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
+                className="cursor-pointer hover:bg-muted/50"
+              >
+                <TableCell>{store.id}</TableCell>
+                <TableCell>{store.name}</TableCell>
+                <TableCell>
+                  {new Date(store.created_at * 1000).toLocaleString()}
+                </TableCell>
+                <TableCell>{fileCounts.completed}</TableCell>
+                <TableCell>{fileCounts.cancelled}</TableCell>
+                <TableCell>{fileCounts.failed}</TableCell>
+                <TableCell>{fileCounts.in_progress}</TableCell>
+                <TableCell>{fileCounts.total}</TableCell>
+                <TableCell>{store.usage_bytes}</TableCell>
+                <TableCell>{providerId}</TableCell>
+                <TableCell>{providerDbId}</TableCell>
+              </TableRow>
+            );
+          })}
+        </TableBody>
+      </Table>
+    </div>
+  );
+}
@@ -1,6 +1,11 @@
 "use client";

-import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react";
+import {
+  MessageSquareText,
+  MessagesSquare,
+  MoveUpRight,
+  Database,
+} from "lucide-react";
 import Link from "next/link";
 import { usePathname } from "next/navigation";
 import { cn } from "@/lib/utils";
@@ -28,6 +33,11 @@ const logItems = [
     url: "/logs/responses",
     icon: MessagesSquare,
   },
+  {
+    title: "Vector Stores",
+    url: "/logs/vector-stores",
+    icon: Database,
+  },
   {
     title: "Documentation",
     url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
@@ -57,13 +67,13 @@ export function AppSidebar() {
                   className={cn(
                     "justify-start",
                     isActive &&
-                      "bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary",
+                      "bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
                   )}
                 >
                   <Link href={item.url}>
                     <item.icon
                       className={cn(
-                        isActive && "text-primary",
+                        isActive && "text-gray-900 dark:text-gray-100",
                         "mr-2 h-4 w-4",
                       )}
                     />
@@ -93,7 +93,9 @@ export function PropertyItem({
     >
       <strong>{label}:</strong>{" "}
       {typeof value === "string" || typeof value === "number" ? (
-        <span className="text-gray-900 font-medium">{value}</span>
+        <span className="text-gray-900 dark:text-gray-100 font-medium">
+          {value}
+        </span>
       ) : (
         value
       )}
@@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) {
         <CardTitle>Properties</CardTitle>
       </CardHeader>
       <CardContent>
-        <ul className="space-y-2 text-sm text-gray-600">{children}</ul>
+        <ul className="space-y-2 text-sm text-gray-600 dark:text-gray-400">
+          {children}
+        </ul>
       </CardContent>
     </Card>
   );
@@ -17,10 +17,10 @@ export const MessageBlock: React.FC<MessageBlockProps> = ({
 }) => {
   return (
     <div className={`mb-4 ${className}`}>
-      <p className="py-1 font-semibold text-gray-800 mb-1">
+      <p className="py-1 font-semibold text-muted-foreground mb-1">
         {label}
         {labelDetail && (
-          <span className="text-xs text-gray-500 font-normal ml-1">
+          <span className="text-xs text-muted-foreground font-normal ml-1">
             {labelDetail}
           </span>
         )}
128 llama_stack/ui/components/vector-stores/vector-store-detail.tsx Normal file
@@ -0,0 +1,128 @@
+"use client";
+
+import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
+import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
+import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
+import { Skeleton } from "@/components/ui/skeleton";
+import {
+  DetailLoadingView,
+  DetailErrorView,
+  DetailNotFoundView,
+  DetailLayout,
+  PropertiesCard,
+  PropertyItem,
+} from "@/components/layout/detail-layout";
+import {
+  Table,
+  TableBody,
+  TableCaption,
+  TableCell,
+  TableHead,
+  TableHeader,
+  TableRow,
+} from "@/components/ui/table";
+
+interface VectorStoreDetailViewProps {
+  store: VectorStore | null;
+  files: VectorStoreFile[];
+  isLoadingStore: boolean;
+  isLoadingFiles: boolean;
+  errorStore: Error | null;
+  errorFiles: Error | null;
+  id: string;
+}
+
+export function VectorStoreDetailView({
+  store,
+  files,
+  isLoadingStore,
+  isLoadingFiles,
+  errorStore,
+  errorFiles,
+  id,
+}: VectorStoreDetailViewProps) {
+  const title = "Vector Store Details";
+
+  if (errorStore) {
+    return <DetailErrorView title={title} id={id} error={errorStore} />;
+  }
+  if (isLoadingStore) {
+    return <DetailLoadingView title={title} />;
+  }
+  if (!store) {
+    return <DetailNotFoundView title={title} id={id} />;
+  }
+
+  const mainContent = (
+    <>
+      <Card>
+        <CardHeader>
+          <CardTitle>Files</CardTitle>
+        </CardHeader>
+        <CardContent>
+          {isLoadingFiles ? (
+            <Skeleton className="h-4 w-full" />
+          ) : errorFiles ? (
+            <div className="text-destructive text-sm">
+              Error loading files: {errorFiles.message}
+            </div>
+          ) : files.length > 0 ? (
+            <Table>
+              <TableCaption>Files in this vector store</TableCaption>
+              <TableHeader>
+                <TableRow>
+                  <TableHead>ID</TableHead>
+                  <TableHead>Status</TableHead>
+                  <TableHead>Created</TableHead>
+                  <TableHead>Usage Bytes</TableHead>
+                </TableRow>
+              </TableHeader>
+              <TableBody>
+                {files.map((file) => (
+                  <TableRow key={file.id}>
+                    <TableCell>{file.id}</TableCell>
+                    <TableCell>{file.status}</TableCell>
+                    <TableCell>
+                      {new Date(file.created_at * 1000).toLocaleString()}
+                    </TableCell>
+                    <TableCell>{file.usage_bytes}</TableCell>
+                  </TableRow>
+                ))}
+              </TableBody>
+            </Table>
+          ) : (
+            <p className="text-gray-500 italic text-sm">
+              No files in this vector store.
+            </p>
+          )}
+        </CardContent>
+      </Card>
+    </>
+  );
+
+  const sidebar = (
+    <PropertiesCard>
+      <PropertyItem label="ID" value={store.id} />
+      <PropertyItem label="Name" value={store.name || ""} />
+      <PropertyItem
+        label="Created"
+        value={new Date(store.created_at * 1000).toLocaleString()}
+      />
+      <PropertyItem label="Status" value={store.status} />
+      <PropertyItem label="Total Files" value={store.file_counts.total} />
+      <PropertyItem label="Usage Bytes" value={store.usage_bytes} />
+      <PropertyItem
+        label="Provider ID"
+        value={(store.metadata.provider_id as string) || ""}
+      />
+      <PropertyItem
+        label="Provider DB ID"
+        value={(store.metadata.provider_vector_db_id as string) || ""}
+      />
+    </PropertiesCard>
+  );
+
+  return (
+    <DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
+  );
+}
474 llama_stack/ui/package-lock.json generated
@@ -15,7 +15,7 @@
         "@radix-ui/react-tooltip": "^1.2.6",
         "class-variance-authority": "^0.7.1",
         "clsx": "^2.1.1",
-        "llama-stack-client": "0.2.13",
+        "llama-stack-client": "^0.2.14",
         "lucide-react": "^0.510.0",
         "next": "15.3.3",
         "next-auth": "^4.24.11",
@@ -676,406 +676,6 @@
         "tslib": "^2.4.0"
       }
     },
[406 removed lines of generated "@esbuild/*" platform-specific optional dependency entries elided here: aix-ppc64, android-arm, android-arm64, android-x64, darwin-arm64, darwin-x64, freebsd-arm64, freebsd-x64, linux-arm, linux-arm64, linux-ia32, linux-loong64, linux-mips64el, linux-ppc64, and further platform variants, all at version 0.25.5]
|
|
||||||
"cpu": [
|
|
||||||
"ppc64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/linux-riscv64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-kTxwu4mLyeOlsVIFPfQo+fQJAV9mh24xL+y+Bm6ej067sYANjyEw1dNHmvoqxJUCMnkBdKpvOn0Ahql6+4VyeA==",
|
|
||||||
"cpu": [
|
|
||||||
"riscv64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/linux-s390x": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-K2dSKTKfmdh78uJ3NcWFiqyRrimfdinS5ErLSn3vluHNeHVnBAFWC8a4X5N+7FgVE1EjXS1QDZbpqZBjfrqMTQ==",
|
|
||||||
"cpu": [
|
|
||||||
"s390x"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/linux-x64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-uhj8N2obKTE6pSZ+aMUbqq+1nXxNjZIIjCjGLfsWvVpy7gKCOL6rsY1MhRh9zLtUtAI7vpgLMK6DxjO8Qm9lJw==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/netbsd-arm64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-pwHtMP9viAy1oHPvgxtOv+OkduK5ugofNTVDilIzBLpoWAM16r7b/mxBvfpuQDpRQFMfuVr5aLcn4yveGvBZvw==",
|
|
||||||
"cpu": [
|
|
||||||
"arm64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"netbsd"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/netbsd-x64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-WOb5fKrvVTRMfWFNCroYWWklbnXH0Q5rZppjq0vQIdlsQKuw6mdSihwSo4RV/YdQ5UCKKvBy7/0ZZYLBZKIbwQ==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"netbsd"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/openbsd-arm64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-7A208+uQKgTxHd0G0uqZO8UjK2R0DDb4fDmERtARjSHWxqMTye4Erz4zZafx7Di9Cv+lNHYuncAkiGFySoD+Mw==",
|
|
||||||
"cpu": [
|
|
||||||
"arm64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"openbsd"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/openbsd-x64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-G4hE405ErTWraiZ8UiSoesH8DaCsMm0Cay4fsFWOOUcz8b8rC6uCvnagr+gnioEjWn0wC+o1/TAHt+It+MpIMg==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"openbsd"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/sunos-x64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-l+azKShMy7FxzY0Rj4RCt5VD/q8mG/e+mDivgspo+yL8zW7qEwctQ6YqKX34DTEleFAvCIUviCFX1SDZRSyMQA==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"sunos"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/win32-arm64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-O2S7SNZzdcFG7eFKgvwUEZ2VG9D/sn/eIiz8XRZ1Q/DO5a3s76Xv0mdBzVM5j5R639lXQmPmSo0iRpHqUUrsxw==",
|
|
||||||
"cpu": [
|
|
||||||
"arm64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"win32"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/win32-ia32": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-onOJ02pqs9h1iMJ1PQphR+VZv8qBMQ77Klcsqv9CNW2w6yLqoURLcgERAIurY6QE63bbLuqgP9ATqajFLK5AMQ==",
|
|
||||||
"cpu": [
|
|
||||||
"ia32"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"win32"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@esbuild/win32-x64": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-TXv6YnJ8ZMVdX+SXWVBo/0p8LTcrUYngpWjvm91TMjjBQii7Oz11Lw5lbDV5Y0TzuhSJHwiH4hEtC1I42mMS0g==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"license": "MIT",
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"win32"
|
|
||||||
],
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/@eslint-community/eslint-utils": {
|
"node_modules/@eslint-community/eslint-utils": {
|
||||||
"version": "4.7.0",
|
"version": "4.7.0",
|
||||||
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
|
||||||
|
|
@ -5999,46 +5599,6 @@
|
||||||
"url": "https://github.com/sponsors/ljharb"
|
"url": "https://github.com/sponsors/ljharb"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/esbuild": {
|
|
||||||
"version": "0.25.5",
|
|
||||||
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz",
|
|
||||||
"integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==",
|
|
||||||
"hasInstallScript": true,
|
|
||||||
"license": "MIT",
|
|
||||||
"bin": {
|
|
||||||
"esbuild": "bin/esbuild"
|
|
||||||
},
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18"
|
|
||||||
},
|
|
||||||
"optionalDependencies": {
|
|
||||||
"@esbuild/aix-ppc64": "0.25.5",
|
|
||||||
"@esbuild/android-arm": "0.25.5",
|
|
||||||
"@esbuild/android-arm64": "0.25.5",
|
|
||||||
"@esbuild/android-x64": "0.25.5",
|
|
||||||
"@esbuild/darwin-arm64": "0.25.5",
|
|
||||||
"@esbuild/darwin-x64": "0.25.5",
|
|
||||||
"@esbuild/freebsd-arm64": "0.25.5",
|
|
||||||
"@esbuild/freebsd-x64": "0.25.5",
|
|
||||||
"@esbuild/linux-arm": "0.25.5",
|
|
||||||
"@esbuild/linux-arm64": "0.25.5",
|
|
||||||
"@esbuild/linux-ia32": "0.25.5",
|
|
||||||
"@esbuild/linux-loong64": "0.25.5",
|
|
||||||
"@esbuild/linux-mips64el": "0.25.5",
|
|
||||||
"@esbuild/linux-ppc64": "0.25.5",
|
|
||||||
"@esbuild/linux-riscv64": "0.25.5",
|
|
||||||
"@esbuild/linux-s390x": "0.25.5",
|
|
||||||
"@esbuild/linux-x64": "0.25.5",
|
|
||||||
"@esbuild/netbsd-arm64": "0.25.5",
|
|
||||||
"@esbuild/netbsd-x64": "0.25.5",
|
|
||||||
"@esbuild/openbsd-arm64": "0.25.5",
|
|
||||||
"@esbuild/openbsd-x64": "0.25.5",
|
|
||||||
"@esbuild/sunos-x64": "0.25.5",
|
|
||||||
"@esbuild/win32-arm64": "0.25.5",
|
|
||||||
"@esbuild/win32-ia32": "0.25.5",
|
|
||||||
"@esbuild/win32-x64": "0.25.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/escalade": {
|
"node_modules/escalade": {
|
||||||
"version": "3.2.0",
|
"version": "3.2.0",
|
||||||
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
|
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
|
||||||
|
|
@ -6993,6 +6553,7 @@
|
||||||
"version": "2.3.3",
|
"version": "2.3.3",
|
||||||
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
|
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
|
||||||
"integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
|
"integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
|
||||||
|
"dev": true,
|
||||||
"hasInstallScript": true,
|
"hasInstallScript": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"optional": true,
|
"optional": true,
|
||||||
|
|
@ -7154,6 +6715,7 @@
|
||||||
"version": "4.10.0",
|
"version": "4.10.0",
|
||||||
"resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz",
|
"resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz",
|
||||||
"integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==",
|
"integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==",
|
||||||
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"resolve-pkg-maps": "^1.0.0"
|
"resolve-pkg-maps": "^1.0.0"
|
||||||
|
|
@ -9537,9 +9099,10 @@
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
"node_modules/llama-stack-client": {
|
"node_modules/llama-stack-client": {
|
||||||
"version": "0.2.13",
|
"version": "0.2.14",
|
||||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.13.tgz",
|
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
|
||||||
"integrity": "sha512-R1rTFLwgUimr+KjEUkzUvFL6vLASwS9qj3UDSVkJ5BmrKAs5GwVAMeL7yZaTBXGuPUVh124WSlC4d9H0FjWqLA==",
|
"integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
|
||||||
|
"license": "Apache-2.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@types/node": "^18.11.18",
|
"@types/node": "^18.11.18",
|
||||||
"@types/node-fetch": "^2.6.4",
|
"@types/node-fetch": "^2.6.4",
|
||||||
|
|
@ -9547,8 +9110,7 @@
|
||||||
"agentkeepalive": "^4.2.1",
|
"agentkeepalive": "^4.2.1",
|
||||||
"form-data-encoder": "1.7.2",
|
"form-data-encoder": "1.7.2",
|
||||||
"formdata-node": "^4.3.2",
|
"formdata-node": "^4.3.2",
|
||||||
"node-fetch": "^2.6.7",
|
"node-fetch": "^2.6.7"
|
||||||
"tsx": "^4.19.2"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/llama-stack-client/node_modules/@types/node": {
|
"node_modules/llama-stack-client/node_modules/@types/node": {
|
||||||
|
|
@ -11148,6 +10710,7 @@
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
|
||||||
"integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==",
|
"integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==",
|
||||||
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"funding": {
|
"funding": {
|
||||||
"url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1"
|
"url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1"
|
||||||
|
|
@ -12198,25 +11761,6 @@
|
||||||
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
|
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
|
||||||
"license": "0BSD"
|
"license": "0BSD"
|
||||||
},
|
},
|
||||||
"node_modules/tsx": {
|
|
||||||
"version": "4.19.4",
|
|
||||||
"resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.4.tgz",
|
|
||||||
"integrity": "sha512-gK5GVzDkJK1SI1zwHf32Mqxf2tSJkNx+eYcNly5+nHvWqXUJYUkWBQtKauoESz3ymezAI++ZwT855x5p5eop+Q==",
|
|
||||||
"license": "MIT",
|
|
||||||
"dependencies": {
|
|
||||||
"esbuild": "~0.25.0",
|
|
||||||
"get-tsconfig": "^4.7.5"
|
|
||||||
},
|
|
||||||
"bin": {
|
|
||||||
"tsx": "dist/cli.mjs"
|
|
||||||
},
|
|
||||||
"engines": {
|
|
||||||
"node": ">=18.0.0"
|
|
||||||
},
|
|
||||||
"optionalDependencies": {
|
|
||||||
"fsevents": "~2.3.3"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/tw-animate-css": {
|
"node_modules/tw-animate-css": {
|
||||||
"version": "1.2.9",
|
"version": "1.2.9",
|
||||||
"resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz",
|
"resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz",
|
||||||

@ -20,7 +20,7 @@
"@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"llama-stack-client": "0.2.13",
"llama-stack-client": "^0.2.14",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",

@ -64,6 +64,7 @@ dev = [
    "pytest-cov",
    "pytest-html",
    "pytest-json-report",
    "pytest-socket",  # For blocking network access in unit tests
    "nbval",  # For notebook testing
    "black",
    "ruff",

@ -344,3 +345,6 @@ classmethod-decorators = ["classmethod", "pydantic.field_validator"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
markers = [
    "allow_network: Allow network access for specific unit tests",
]

@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
    return agent_config


@pytest.fixture(scope="session")
def agent_config_without_safety(text_model_id):
    agent_config = dict(
        model=text_model_id,
        instructions="You are a helpful assistant",
        sampling_params={
            "strategy": {
                "type": "top_p",
                "temperature": 0.0001,
                "top_p": 0.9,
            },
        },
        tools=[],
        enable_session_persistence=False,
    )
    return agent_config


def test_agent_simple(llama_stack_client, agent_config):
    agent = Agent(llama_stack_client, **agent_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")

@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
    assert expected_kw in response.output_message.content.lower()


def test_rag_agent_with_attachments(llama_stack_client, agent_config):
def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
    urls = ["llama3.rst", "lora_finetune.rst"]
    documents = [
        # passing as url

@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
            metadata={},
        ),
    ]
    rag_agent = Agent(llama_stack_client, **agent_config)
    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
    session_id = rag_agent.create_session(f"test-session-{uuid4()}")
    user_prompts = [
        (
            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
            "grouped",
        ),
    ]
    user_prompts = [
        (
            "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",

@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
    assert "lora" in response.output_message.content.lower()


@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_rag_and_code_agent(llama_stack_client, agent_config):
    if "llama-4" in agent_config["model"].lower():
        pytest.xfail("Not working for llama4")

    documents = []
    documents.append(
        Document(
            document_id="nba_wiki",
            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
            metadata={},
        )
    )
    documents.append(
        Document(
            document_id="perplexity_wiki",
            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:

    Srinivas, the CEO, worked at OpenAI as an AI researcher.
    Konwinski was among the founding team at Databricks.
    Yarats, the CTO, was an AI research scientist at Meta.
    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
            metadata={},
        )
    )
    vector_db_id = f"test-vector-db-{uuid4()}"
    llama_stack_client.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
    llama_stack_client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=128,
    )
    agent_config = {
        **agent_config,
        "tools": [
            dict(
                name="builtin::rag/knowledge_search",
                args={"vector_db_ids": [vector_db_id]},
            ),
            "builtin::code_interpreter",
        ],
    }
    agent = Agent(llama_stack_client, **agent_config)
    user_prompts = [
        (
            "when was Perplexity the company founded?",
            [],
            "knowledge_search",
            "2022",
        ),
        (
            "when was the nba created?",
            [],
            "knowledge_search",
            "1949",
        ),
    ]

    for prompt, docs, tool_name, expected_kw in user_prompts:
        session_id = agent.create_session(f"test-session-{uuid4()}")
        response = agent.create_turn(
            messages=[{"role": "user", "content": prompt}],
            session_id=session_id,
            documents=docs,
            stream=False,
        )
        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
        if expected_kw:
            assert expected_kw in response.output_message.content.lower()


@pytest.mark.parametrize(
    "client_tools",
    [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],

@ -4,6 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest_socket

# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore  # noqa: F401


def pytest_runtest_setup(item):
    """Setup for each test - check if network access should be allowed."""
    if "allow_network" in item.keywords:
        pytest_socket.enable_socket()
    else:
        # Allowing Unix sockets is necessary for some tests that use local servers and mocks
        pytest_socket.disable_socket(allow_unix_socket=True)

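With this hook in place, unit tests run network-isolated by default and opt back in via the allow_network marker registered in pyproject.toml above. A minimal sketch of how a test interacts with it (the two test bodies are illustrative, not part of this commit):

import socket

import pytest


@pytest.mark.allow_network  # pytest_runtest_setup() calls pytest_socket.enable_socket() for this test
def test_can_open_tcp_socket():
    # Without the marker, pytest-socket raises SocketBlockedError on socket creation
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.close()


def test_runs_offline_by_default():
    # Unmarked tests get disable_socket(allow_unix_socket=True), so TCP is
    # blocked while local Unix-socket servers and mocks keep working
    ...
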
@ -393,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
    assert chunks[0].event.event_type.value == "start"


@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
    loop = asyncio.new_event_loop()
    loop.set_debug(True)

@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
    return ModelRegistryHelper([known_provider_model, known_provider_model2])


class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
    """Test helper that simulates a provider with dynamically available models."""

    def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
        super().__init__(model_entries)
        self._available_models = available_models

    async def check_model_availability(self, model: str) -> bool:
        return model in self._available_models


@pytest.fixture
def dynamic_model() -> Model:
    """A model that's not in static config but available dynamically."""
    return Model(
        provider_id="provider",
        identifier="dynamic-model",
        provider_resource_id="dynamic-provider-id",
    )


@pytest.fixture
def helper_with_dynamic_models(
    known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model
) -> MockModelRegistryHelperWithDynamicModels:
    """Helper that includes dynamically available models."""
    return MockModelRegistryHelperWithDynamicModels(
        [known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id]
    )


async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
    assert helper.get_provider_model_id(unknown_model.model_id) is None

@ -151,3 +182,63 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m
    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
    await helper.unregister_model(known_model.provider_resource_id)
    assert helper.get_provider_model_id(known_model.provider_resource_id) is None


async def test_register_model_from_check_model_availability(
    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:
    """Test that models returned by check_model_availability can be registered."""
    # Verify the model is not in static config
    assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None

    # But it should be available via check_model_availability
    is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
    assert is_available

    # Registration should succeed
    registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
    assert registered_model == dynamic_model

    # Model should now be registered and accessible
    assert (
        helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id
    )


async def test_register_model_not_in_static_or_dynamic(
    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
) -> None:
    """Test that models not in static config or dynamic models are rejected."""
    # Verify the model is not in static config
    assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None

    # And not available via check_model_availability
    is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
    assert not is_available

    # Registration should fail with comprehensive error message
    with pytest.raises(Exception) as exc_info:  # UnsupportedModelError
        await helper_with_dynamic_models.register_model(unknown_model)

    # Error should include static models and "..." for dynamic models
    error_str = str(exc_info.value)
    assert "..." in error_str  # "..." should be in error message


async def test_register_alias_for_dynamic_model(
    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:
    """Test that we can register an alias that maps to a dynamically available model."""
    # Create a model with a different identifier but same provider_resource_id
    alias_model = Model(
        provider_id=dynamic_model.provider_id,
        identifier="dynamic-model-alias",
        provider_resource_id=dynamic_model.provider_resource_id,
    )

    # Registration should succeed since the provider_resource_id is available dynamically
    registered_model = await helper_with_dynamic_models.register_model(alias_model)
    assert registered_model == alias_model

    # Both the original provider_resource_id and the new alias should work
    assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id

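These tests pin down the fallback path in ModelRegistryHelper: when a model id is missing from the static entry list, register_model() consults check_model_availability() before rejecting it. A concrete provider adapter would override the same hook, roughly like this (ExampleAdapter and its list_models() client call are hypothetical, shown only to illustrate the pattern):

class ExampleAdapter(ModelRegistryHelper):
    async def check_model_availability(self, model: str) -> bool:
        # Ask the backing inference service which models it can actually serve,
        # so unlisted models and aliases can still be registered dynamically.
        served = await self._client.list_models()  # hypothetical client call
        return model in served
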
@ -12,6 +12,8 @@ from pymilvus import MilvusClient, connections

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter

@ -90,7 +92,7 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
    return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])


@pytest.fixture(params=["milvus", "sqlite_vec"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
def vector_provider(request):
    return request.param

@ -116,7 +118,7 @@ async def unique_kvstore_config(tmp_path_factory):

@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
    db_path = str(tmp_path_factory.getbasetemp() / "test.db")
    db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
    return db_path

@ -198,11 +200,49 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
    await adapter.shutdown()


@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
    db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
    return db_path


@pytest.fixture
async def faiss_vec_index(embedding_dimension):
    index = FaissIndex(embedding_dimension)
    yield index
    await index.delete()


@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
    config = FaissVectorIOConfig(
        kvstore=unique_kvstore_config,
    )
    adapter = FaissVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=embedding_dimension,
        )
    )
    yield adapter
    await adapter.shutdown()


@pytest.fixture
def vector_io_adapter(vector_provider, request):
    """Returns the appropriate vector IO adapter based on the provider parameter."""
    if vector_provider == "milvus":
        return request.getfixturevalue("milvus_vec_adapter")
    elif vector_provider == "faiss":
        return request.getfixturevalue("faiss_vec_adapter")
    else:
        return request.getfixturevalue("sqlite_vec_adapter")

191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import pytest_asyncio

from llama_stack.apis.vector_io import QueryChunksResponse

# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock

# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex

# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
#     -v -s --tb=short --disable-warnings --asyncio-mode=auto

MILVUS_PROVIDER = "milvus"


@pytest_asyncio.fixture
async def mock_milvus_client() -> MagicMock:
    """Create a mock Milvus client with common method behaviors."""
    client = MagicMock()

    # Mock collection operations
    client.has_collection.return_value = False  # Initially no collection
    client.create_collection.return_value = None
    client.drop_collection.return_value = None

    # Mock insert operation
    client.insert.return_value = {"insert_count": 10}

    # Mock search operation - return mock results (data should be dict, not JSON string)
    client.search.return_value = [
        [
            {
                "id": 0,
                "distance": 0.1,
                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
            },
            {
                "id": 1,
                "distance": 0.2,
                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
            },
        ]
    ]

    # Mock query operation for keyword search (data should be dict, not JSON string)
    client.query.return_value = [
        {
            "chunk_id": "chunk1",
            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
            "score": 0.9,
        },
        {
            "chunk_id": "chunk2",
            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
            "score": 0.8,
        },
        {
            "chunk_id": "chunk3",
            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
            "score": 0.7,
        },
    ]

    return client


@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
    """Create a MilvusIndex with mocked client."""
    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
    yield index
    # No real cleanup needed since we're using mocks


@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    # Setup: collection doesn't exist initially, then exists after creation
    mock_milvus_client.has_collection.side_effect = [False, True]

    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Verify collection was created and data was inserted
    mock_milvus_client.create_collection.assert_called_once()
    mock_milvus_client.insert.assert_called_once()

    # Verify the insert call had the right number of chunks
    insert_call = mock_milvus_client.insert.call_args
    assert len(insert_call[1]["data"]) == len(sample_chunks)


@pytest.mark.asyncio
async def test_query_chunks_vector(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    # Setup: Add chunks first
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Test vector search
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2
    mock_milvus_client.search.assert_called_once()


@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Test keyword search
    query_string = "Sentence 5"
    response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2


@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """Test that when BM25 search fails, the system falls back to simple text search."""
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Force BM25 search to fail
    mock_milvus_client.search.side_effect = Exception("BM25 search not available")

    # Mock simple text search results
    mock_milvus_client.query.return_value = [
        {
            "chunk_id": "chunk1",
            "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
        },
        {
            "chunk_id": "chunk2",
            "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
        },
    ]

    # Test keyword search that should fall back to simple text search
    query_string = "Python"
    response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)

    # Verify response structure
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) > 0, "Fallback search should return results"

    # Verify that simple text search was used (query method called instead of search)
    mock_milvus_client.query.assert_called_once()
    mock_milvus_client.search.assert_called_once()  # Called once but failed

    # Verify the query uses parameterized filter with filter_params
    query_call_args = mock_milvus_client.query.call_args
    assert "filter" in query_call_args[1], "Query should include filter for text search"
    assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
    assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"

    # Verify all returned chunks have score 1.0 (simple binary scoring)
    assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"


@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
    # Test collection deletion
    mock_milvus_client.has_collection.return_value = True

    await milvus_index.delete()

    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)

@ -94,7 +94,7 @@ async def test_query_unregistered_raises(vector_io_adapter):

async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
    fake_index = AsyncMock()
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
    vector_io_adapter.cache["db1"] = fake_index

    chunks = ["chunk1", "chunk2"]
    await vector_io_adapter.insert_chunks("db1", chunks)

@ -112,7 +112,7 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
    expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
    fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
    vector_io_adapter.cache["db1"] = fake_index

    response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})

@ -286,5 +286,7 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
    await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)

    loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
    assert loaded_file_info == {}
    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
    assert loaded_contents == []

@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
    Chunk,
    ChunkMetadata,

@ -58,3 +59,14 @@ class TestRagQuery:
        )
        assert expected_metadata_string in result.content[1].text
        assert result.content is not None

    async def test_query_raises_incorrect_mode(self):
        with pytest.raises(ValueError):
            RAGQueryConfig(mode="invalid_mode")

    @pytest.mark.asyncio
    async def test_query_accepts_valid_modes(self):
        RAGQueryConfig()  # Test default (vector)
        RAGQueryConfig(mode="vector")  # Test vector
        RAGQueryConfig(mode="keyword")  # Test keyword
        RAGQueryConfig(mode="hybrid")  # Test hybrid

@ -123,6 +123,7 @@ class TestVectorStore:
        content = await content_from_doc(doc)
        assert content in DUMMY_PDF_TEXT_CHOICES

    @pytest.mark.allow_network
    async def test_downloads_pdf_and_returns_content(self):
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"

@ -135,6 +136,7 @@ class TestVectorStore:
        content = await content_from_doc(doc)
        assert content in DUMMY_PDF_TEXT_CHOICES

    @pytest.mark.allow_network
    async def test_downloads_pdf_and_returns_content_with_url_object(self):
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"

14
uv.lock
generated

@ -1324,6 +1324,7 @@ dev = [
    { name = "pytest-cov" },
    { name = "pytest-html" },
    { name = "pytest-json-report" },
    { name = "pytest-socket" },
    { name = "pytest-timeout" },
    { name = "ruamel-yaml" },
    { name = "ruff" },

@ -1432,6 +1433,7 @@ dev = [
    { name = "pytest-cov" },
    { name = "pytest-html" },
    { name = "pytest-json-report" },
    { name = "pytest-socket" },
    { name = "pytest-timeout" },
    { name = "ruamel-yaml" },
    { name = "ruff" },

@ -2545,6 +2547,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" },
]

[[package]]
name = "pytest-socket"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/05/ff/90c7e1e746baf3d62ce864c479fd53410b534818b9437413903596f81580/pytest_socket-0.7.0.tar.gz", hash = "sha256:71ab048cbbcb085c15a4423b73b619a8b35d6a307f46f78ea46be51b1b7e11b3", size = 12389, upload-time = "2024-01-28T20:17:23.177Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/19/58/5d14cb5cb59409e491ebe816c47bf81423cd03098ea92281336320ae5681/pytest_socket-0.7.0-py3-none-any.whl", hash = "sha256:7e0f4642177d55d317bbd58fc68c6bd9048d6eadb2d46a89307fa9221336ce45", size = 6754, upload-time = "2024-01-28T20:17:22.105Z" },
]

[[package]]
name = "pytest-timeout"
version = "2.4.0"