mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge branch 'main' into vectordb_name
This commit is contained in:
commit
b103ee9eb8
54 changed files with 4277 additions and 1773 deletions
9
.github/workflows/integration-auth-tests.yml
vendored
9
.github/workflows/integration-auth-tests.yml
vendored
|
@ -73,9 +73,12 @@ jobs:
|
||||||
server:
|
server:
|
||||||
port: 8321
|
port: 8321
|
||||||
EOF
|
EOF
|
||||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml
|
||||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml
|
||||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml
|
||||||
|
yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml
|
||||||
|
yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml
|
||||||
|
yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
|
||||||
cat $run_dir/run.yaml
|
cat $run_dir/run.yaml
|
||||||
|
|
||||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||||
|
|
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
name: SqlStore Integration Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'llama_stack/providers/utils/sqlstore/**'
|
||||||
|
- 'tests/integration/sqlstore/**'
|
||||||
|
- 'uv.lock'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test-postgres:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.12", "3.13"]
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:15
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: llamastack
|
||||||
|
POSTGRES_PASSWORD: llamastack
|
||||||
|
POSTGRES_DB: llamastack
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
uses: ./.github/actions/setup-runner
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Run SqlStore Integration Tests
|
||||||
|
env:
|
||||||
|
ENABLE_POSTGRES_TESTS: "true"
|
||||||
|
POSTGRES_HOST: localhost
|
||||||
|
POSTGRES_PORT: 5432
|
||||||
|
POSTGRES_DB: llamastack
|
||||||
|
POSTGRES_USER: llamastack
|
||||||
|
POSTGRES_PASSWORD: llamastack
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv tests/integration/providers/utils/sqlstore/
|
||||||
|
|
||||||
|
- name: Upload test logs
|
||||||
|
if: ${{ always() }}
|
||||||
|
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||||
|
with:
|
||||||
|
name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
*.log
|
||||||
|
retention-days: 1
|
|
@ -56,8 +56,8 @@ shields: []
|
||||||
server:
|
server:
|
||||||
port: 8321
|
port: 8321
|
||||||
auth:
|
auth:
|
||||||
provider_type: "oauth2_token"
|
provider_config:
|
||||||
config:
|
type: "oauth2_token"
|
||||||
jwks:
|
jwks:
|
||||||
uri: "https://my-token-issuing-svc.com/jwks"
|
uri: "https://my-token-issuing-svc.com/jwks"
|
||||||
```
|
```
|
||||||
|
@ -226,6 +226,8 @@ server:
|
||||||
|
|
||||||
### Authentication Configuration
|
### Authentication Configuration
|
||||||
|
|
||||||
|
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
||||||
|
|
||||||
The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
|
The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -240,8 +242,8 @@ The server can be configured to use service account tokens for authorization, va
|
||||||
```yaml
|
```yaml
|
||||||
server:
|
server:
|
||||||
auth:
|
auth:
|
||||||
provider_type: "oauth2_token"
|
provider_config:
|
||||||
config:
|
type: "oauth2_token"
|
||||||
jwks:
|
jwks:
|
||||||
uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
|
uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
|
||||||
token: "${env.TOKEN:+}"
|
token: "${env.TOKEN:+}"
|
||||||
|
@ -325,13 +327,25 @@ You can easily validate a request by running:
|
||||||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### GitHub Token Provider
|
||||||
|
Validates GitHub personal access tokens or OAuth tokens directly:
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
auth:
|
||||||
|
provider_config:
|
||||||
|
type: "github_token"
|
||||||
|
github_api_base_url: "https://api.github.com" # Or GitHub Enterprise URL
|
||||||
|
```
|
||||||
|
|
||||||
|
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
||||||
|
|
||||||
#### Custom Provider
|
#### Custom Provider
|
||||||
Validates tokens against a custom authentication endpoint:
|
Validates tokens against a custom authentication endpoint:
|
||||||
```yaml
|
```yaml
|
||||||
server:
|
server:
|
||||||
auth:
|
auth:
|
||||||
provider_type: "custom"
|
provider_config:
|
||||||
config:
|
type: "custom"
|
||||||
endpoint: "https://auth.example.com/validate" # URL of the auth endpoint
|
endpoint: "https://auth.example.com/validate" # URL of the auth endpoint
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -416,8 +430,8 @@ clients.
|
||||||
server:
|
server:
|
||||||
port: 8321
|
port: 8321
|
||||||
auth:
|
auth:
|
||||||
provider_type: custom
|
provider_config:
|
||||||
config:
|
type: custom
|
||||||
endpoint: https://auth.example.com/validate
|
endpoint: https://auth.example.com/validate
|
||||||
quota:
|
quota:
|
||||||
kvstore:
|
kvstore:
|
||||||
|
|
|
@ -39,6 +39,13 @@ docker pull llama-stack/distribution-meta-reference-gpu
|
||||||
|
|
||||||
**Guides:** [Meta Reference GPU Guide](self_hosted_distro/meta-reference-gpu)
|
**Guides:** [Meta Reference GPU Guide](self_hosted_distro/meta-reference-gpu)
|
||||||
|
|
||||||
|
### 🖥️ Self-Hosted with NVIDA NeMo Microservices
|
||||||
|
|
||||||
|
**Use `nvidia` if you:**
|
||||||
|
- Want to use Llama Stack with NVIDIA NeMo Microservices
|
||||||
|
|
||||||
|
**Guides:** [NVIDIA Distribution Guide](self_hosted_distro/nvidia)
|
||||||
|
|
||||||
### ☁️ Managed Hosting
|
### ☁️ Managed Hosting
|
||||||
|
|
||||||
**Use remote-hosted endpoints if you:**
|
**Use remote-hosted endpoints if you:**
|
||||||
|
|
177
docs/source/distributions/self_hosted_distro/nvidia.md
Normal file
177
docs/source/distributions/self_hosted_distro/nvidia.md
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||||
|
# NVIDIA Distribution
|
||||||
|
|
||||||
|
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
| API | Provider(s) |
|
||||||
|
|-----|-------------|
|
||||||
|
| agents | `inline::meta-reference` |
|
||||||
|
| datasetio | `inline::localfs`, `remote::nvidia` |
|
||||||
|
| eval | `remote::nvidia` |
|
||||||
|
| inference | `remote::nvidia` |
|
||||||
|
| post_training | `remote::nvidia` |
|
||||||
|
| safety | `remote::nvidia` |
|
||||||
|
| scoring | `inline::basic` |
|
||||||
|
| telemetry | `inline::meta-reference` |
|
||||||
|
| tool_runtime | `inline::rag-runtime` |
|
||||||
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The following environment variables can be configured:
|
||||||
|
|
||||||
|
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||||
|
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||||
|
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||||
|
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||||
|
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||||
|
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||||
|
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
||||||
|
- `NVIDIA_GUARDRAILS_CONFIG_ID`: NVIDIA Guardrail Configuration ID (default: `self-check`)
|
||||||
|
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
|
||||||
|
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
||||||
|
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
||||||
|
|
||||||
|
### Models
|
||||||
|
|
||||||
|
The following models are available by default:
|
||||||
|
|
||||||
|
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
|
||||||
|
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||||
|
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
|
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
|
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
|
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||||
|
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
|
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
|
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
|
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
|
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||||
|
- `nvidia/nv-embedqa-e5-v5 `
|
||||||
|
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||||
|
- `snowflake/arctic-embed-l `
|
||||||
|
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
### NVIDIA API Keys
|
||||||
|
|
||||||
|
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||||
|
|
||||||
|
### Deploy NeMo Microservices Platform
|
||||||
|
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||||
|
|
||||||
|
## Supported Services
|
||||||
|
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||||
|
|
||||||
|
### Inference: NVIDIA NIM
|
||||||
|
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||||
|
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||||
|
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||||
|
|
||||||
|
The deployed platform includes the NIM Proxy microservice, which is the service that provides to access your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
|
||||||
|
|
||||||
|
### Datasetio API: NeMo Data Store
|
||||||
|
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Eval API: NeMo Evaluator
|
||||||
|
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Post-Training API: NeMo Customizer
|
||||||
|
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Safety API: NeMo Guardrails
|
||||||
|
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
## Deploying models
|
||||||
|
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||||
|
|
||||||
|
Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||||
|
```sh
|
||||||
|
# URL to NeMo NIM Proxy service
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"name": "llama-3.2-1b-instruct",
|
||||||
|
"namespace": "meta",
|
||||||
|
"config": {
|
||||||
|
"model": "meta/llama-3.2-1b-instruct",
|
||||||
|
"nim_deployment": {
|
||||||
|
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||||
|
"image_tag": "1.8.3",
|
||||||
|
"pvc_size": "25Gi",
|
||||||
|
"gpu": 1,
|
||||||
|
"additional_envs": {
|
||||||
|
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||||
|
|
||||||
|
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||||
|
```sh
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Llama Stack with NVIDIA
|
||||||
|
|
||||||
|
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||||
|
|
||||||
|
### Via Docker
|
||||||
|
|
||||||
|
This method allows you to get started quickly without having to build the distribution code.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLAMA_STACK_PORT=8321
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
--pull always \
|
||||||
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
llamastack/distribution-nvidia \
|
||||||
|
--config /root/my-run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via Conda
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
|
llama stack build --template nvidia --image-type conda
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port 8321 \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via venv
|
||||||
|
|
||||||
|
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
|
llama stack build --template nvidia --image-type venv
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port 8321 \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example Notebooks
|
||||||
|
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
|
@ -12,6 +12,7 @@ Please refer to the remote provider documentation.
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|----------|---------|-------------|
|
||||||
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
|
||||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
|
||||||
|
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
|
|
|
@ -403,15 +403,16 @@ def _run_stack_build_command_from_build_config(
|
||||||
if template_name:
|
if template_name:
|
||||||
# copy run.yaml from template to build_dir instead of generating it again
|
# copy run.yaml from template to build_dir instead of generating it again
|
||||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
|
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
|
||||||
with importlib.resources.as_file(template_path) as path:
|
|
||||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||||
|
|
||||||
|
with importlib.resources.as_file(template_path) as path:
|
||||||
shutil.copy(path, run_config_file)
|
shutil.copy(path, run_config_file)
|
||||||
|
|
||||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||||
cprint(f"You can find the newly-built template here: {template_path}", color="blue", file=sys.stderr)
|
cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr)
|
||||||
cprint(
|
cprint(
|
||||||
"You can run the new Llama Stack distro via: "
|
"You can run the new Llama Stack distro via: "
|
||||||
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "blue"),
|
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||||
color="green",
|
color="green",
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
|
@ -155,7 +155,11 @@ class StackRun(Subcommand):
|
||||||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||||
if callable(getattr(args, arg)):
|
if callable(getattr(args, arg)):
|
||||||
continue
|
continue
|
||||||
if arg == "config" and template_name:
|
if arg == "config":
|
||||||
|
if template_name:
|
||||||
|
server_args.template = str(template_name)
|
||||||
|
else:
|
||||||
|
# Set the config file path
|
||||||
server_args.config = str(config_file)
|
server_args.config = str(config_file)
|
||||||
else:
|
else:
|
||||||
setattr(server_args, arg, getattr(args, arg))
|
setattr(server_args, arg, getattr(args, arg))
|
||||||
|
@ -168,6 +172,9 @@ class StackRun(Subcommand):
|
||||||
run_args.extend([str(args.port)])
|
run_args.extend([str(args.port)])
|
||||||
|
|
||||||
if config_file:
|
if config_file:
|
||||||
|
if template_name:
|
||||||
|
run_args.extend(["--template", str(template_name)])
|
||||||
|
else:
|
||||||
run_args.extend(["--config", str(config_file)])
|
run_args.extend(["--config", str(config_file)])
|
||||||
|
|
||||||
if args.env:
|
if args.env:
|
||||||
|
|
|
@ -6,9 +6,9 @@
|
||||||
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any, Literal, Self
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
|
@ -161,23 +161,113 @@ class LoggingConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2JWKSConfig(BaseModel):
|
||||||
|
# The JWKS URI for collecting public keys
|
||||||
|
uri: str
|
||||||
|
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||||
|
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2IntrospectionConfig(BaseModel):
|
||||||
|
url: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
send_secret_in_body: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AuthProviderType(StrEnum):
|
class AuthProviderType(StrEnum):
|
||||||
"""Supported authentication provider types."""
|
"""Supported authentication provider types."""
|
||||||
|
|
||||||
OAUTH2_TOKEN = "oauth2_token"
|
OAUTH2_TOKEN = "oauth2_token"
|
||||||
|
GITHUB_TOKEN = "github_token"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenAuthConfig(BaseModel):
|
||||||
|
"""Configuration for OAuth2 token authentication."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||||
|
audience: str = Field(default="llama-stack")
|
||||||
|
verify_tls: bool = Field(default=True)
|
||||||
|
tls_cafile: Path | None = Field(default=None)
|
||||||
|
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"sub": "roles",
|
||||||
|
"username": "roles",
|
||||||
|
"groups": "teams",
|
||||||
|
"team": "teams",
|
||||||
|
"project": "projects",
|
||||||
|
"tenant": "namespaces",
|
||||||
|
"namespace": "namespaces",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
|
||||||
|
introspection: OAuth2IntrospectionConfig | None = Field(
|
||||||
|
default=None, description="OAuth2 introspection configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@field_validator("claims_mapping")
|
||||||
|
def validate_claims_mapping(cls, v):
|
||||||
|
for key, value in v.items():
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_mode(self) -> Self:
|
||||||
|
if not self.jwks and not self.introspection:
|
||||||
|
raise ValueError("One of jwks or introspection must be configured")
|
||||||
|
if self.jwks and self.introspection:
|
||||||
|
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAuthConfig(BaseModel):
|
||||||
|
"""Configuration for custom authentication."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||||
|
endpoint: str = Field(
|
||||||
|
...,
|
||||||
|
description="Custom authentication endpoint URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubTokenAuthConfig(BaseModel):
|
||||||
|
"""Configuration for GitHub token authentication."""
|
||||||
|
|
||||||
|
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
|
||||||
|
github_api_base_url: str = Field(
|
||||||
|
default="https://api.github.com",
|
||||||
|
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
|
||||||
|
)
|
||||||
|
claims_mapping: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"login": "roles",
|
||||||
|
"organizations": "teams",
|
||||||
|
},
|
||||||
|
description="Mapping from GitHub user fields to access attributes",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
AuthProviderConfig = Annotated[
|
||||||
|
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationConfig(BaseModel):
|
class AuthenticationConfig(BaseModel):
|
||||||
provider_type: AuthProviderType = Field(
|
"""Top-level authentication configuration."""
|
||||||
|
|
||||||
|
provider_config: AuthProviderConfig = Field(
|
||||||
...,
|
...,
|
||||||
description="Type of authentication provider",
|
description="Authentication provider configuration",
|
||||||
)
|
)
|
||||||
config: dict[str, Any] = Field(
|
access_policy: list[AccessRule] = Field(
|
||||||
...,
|
default=[],
|
||||||
description="Provider-specific configuration",
|
description="Rules for determining access to resources",
|
||||||
)
|
)
|
||||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationRequiredError(Exception):
|
class AuthenticationRequiredError(Exception):
|
||||||
|
|
|
@ -87,8 +87,12 @@ class AuthenticationMiddleware:
|
||||||
headers = dict(scope.get("headers", []))
|
headers = dict(scope.get("headers", []))
|
||||||
auth_header = headers.get(b"authorization", b"").decode()
|
auth_header = headers.get(b"authorization", b"").decode()
|
||||||
|
|
||||||
if not auth_header or not auth_header.startswith("Bearer "):
|
if not auth_header:
|
||||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
error_msg = self.auth_provider.get_auth_error_message(scope)
|
||||||
|
return await self._send_auth_error(send, error_msg)
|
||||||
|
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return await self._send_auth_error(send, "Invalid Authorization header format")
|
||||||
|
|
||||||
token = auth_header.split("Bearer ", 1)[1]
|
token = auth_header.split("Bearer ", 1)[1]
|
||||||
|
|
||||||
|
|
|
@ -8,15 +8,19 @@ import ssl
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio import Lock
|
from asyncio import Lock
|
||||||
from pathlib import Path
|
from urllib.parse import parse_qs, urlparse
|
||||||
from typing import Self
|
|
||||||
from urllib.parse import parse_qs
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthenticationConfig,
|
||||||
|
CustomAuthConfig,
|
||||||
|
GitHubTokenAuthConfig,
|
||||||
|
OAuth2TokenAuthConfig,
|
||||||
|
User,
|
||||||
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
|
||||||
|
|
||||||
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||||
|
|
||||||
params: dict[str, list[str]] = Field(
|
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
|
||||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthRequest(BaseModel):
|
class AuthRequest(BaseModel):
|
||||||
|
@ -62,6 +64,10 @@ class AuthProvider(ABC):
|
||||||
"""Clean up any resources."""
|
"""Clean up any resources."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return provider-specific authentication error message."""
|
||||||
|
return "Authentication required"
|
||||||
|
|
||||||
|
|
||||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||||
attributes: dict[str, list[str]] = {}
|
attributes: dict[str, list[str]] = {}
|
||||||
|
@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
class OAuth2JWKSConfig(BaseModel):
|
|
||||||
# The JWKS URI for collecting public keys
|
|
||||||
uri: str
|
|
||||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
|
||||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2IntrospectionConfig(BaseModel):
|
|
||||||
url: str
|
|
||||||
client_id: str
|
|
||||||
client_secret: str
|
|
||||||
send_secret_in_body: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
|
||||||
audience: str = "llama-stack"
|
|
||||||
verify_tls: bool = True
|
|
||||||
tls_cafile: Path | None = None
|
|
||||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
|
||||||
claims_mapping: dict[str, str] = Field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"sub": "roles",
|
|
||||||
"username": "roles",
|
|
||||||
"groups": "teams",
|
|
||||||
"team": "teams",
|
|
||||||
"project": "projects",
|
|
||||||
"tenant": "namespaces",
|
|
||||||
"namespace": "namespaces",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
jwks: OAuth2JWKSConfig | None
|
|
||||||
introspection: OAuth2IntrospectionConfig | None = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@field_validator("claims_mapping")
|
|
||||||
def validate_claims_mapping(cls, v):
|
|
||||||
for key, value in v.items():
|
|
||||||
if not value:
|
|
||||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_mode(self) -> Self:
|
|
||||||
if not self.jwks and not self.introspection:
|
|
||||||
raise ValueError("One of jwks or introspection must be configured")
|
|
||||||
if self.jwks and self.introspection:
|
|
||||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenAuthProvider(AuthProvider):
|
class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
"""
|
"""
|
||||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||||
|
@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
This should be the standard authentication provider for most use cases.
|
This should be the standard authentication provider for most use cases.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
def __init__(self, config: OAuth2TokenAuthConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._jwks_at: float = 0.0
|
self._jwks_at: float = 0.0
|
||||||
self._jwks: dict[str, str] = {}
|
self._jwks: dict[str, str] = {}
|
||||||
|
@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
issuer=self.config.issuer,
|
issuer=self.config.issuer,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise ValueError(f"Invalid JWT token: {token}") from exc
|
raise ValueError("Invalid JWT token") from exc
|
||||||
|
|
||||||
# There are other standard claims, the most relevant of which is `scope`.
|
# There are other standard claims, the most relevant of which is `scope`.
|
||||||
# We should incorporate these into the access attributes.
|
# We should incorporate these into the access attributes.
|
||||||
|
@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
async def close(self):
|
async def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return OAuth2-specific authentication error message."""
|
||||||
|
if self.config.issuer:
|
||||||
|
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||||
|
elif self.config.introspection:
|
||||||
|
# Extract domain from introspection URL for a cleaner message
|
||||||
|
domain = urlparse(self.config.introspection.url).netloc
|
||||||
|
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||||
|
else:
|
||||||
|
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||||
|
|
||||||
async def _refresh_jwks(self) -> None:
|
async def _refresh_jwks(self) -> None:
|
||||||
"""
|
"""
|
||||||
Refresh the JWKS cache.
|
Refresh the JWKS cache.
|
||||||
|
@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
self._jwks_at = time.time()
|
self._jwks_at = time.time()
|
||||||
|
|
||||||
|
|
||||||
class CustomAuthProviderConfig(BaseModel):
|
|
||||||
endpoint: str
|
|
||||||
|
|
||||||
|
|
||||||
class CustomAuthProvider(AuthProvider):
|
class CustomAuthProvider(AuthProvider):
|
||||||
"""Custom authentication provider that uses an external endpoint."""
|
"""Custom authentication provider that uses an external endpoint."""
|
||||||
|
|
||||||
def __init__(self, config: CustomAuthProviderConfig):
|
def __init__(self, config: CustomAuthConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
|
||||||
try:
|
try:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
auth_response = AuthResponse(**response_data)
|
auth_response = AuthResponse(**response_data)
|
||||||
return User(auth_response.principal, auth_response.attributes)
|
return User(principal=auth_response.principal, attributes=auth_response.attributes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error parsing authentication response")
|
logger.exception("Error parsing authentication response")
|
||||||
raise ValueError("Invalid authentication response format") from e
|
raise ValueError("Invalid authentication response format") from e
|
||||||
|
@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider):
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return custom auth provider-specific authentication error message."""
|
||||||
|
domain = urlparse(self.config.endpoint).netloc
|
||||||
|
if domain:
|
||||||
|
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||||
|
else:
|
||||||
|
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubTokenAuthProvider(AuthProvider):
|
||||||
|
"""
|
||||||
|
GitHub token authentication provider that validates GitHub access tokens directly.
|
||||||
|
|
||||||
|
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
|
||||||
|
them against the GitHub API to get user information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: GitHubTokenAuthConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
"""Validate a GitHub token by calling the GitHub API.
|
||||||
|
|
||||||
|
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.warning(f"GitHub token validation failed: {e}")
|
||||||
|
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
|
||||||
|
|
||||||
|
principal = user_info["user"]["login"]
|
||||||
|
|
||||||
|
github_data = {
|
||||||
|
"login": user_info["user"]["login"],
|
||||||
|
"id": str(user_info["user"]["id"]),
|
||||||
|
"organizations": user_info.get("organizations", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
|
||||||
|
|
||||||
|
return User(
|
||||||
|
principal=principal,
|
||||||
|
attributes=access_attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Clean up any resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||||
|
"""Return GitHub-specific authentication error message."""
|
||||||
|
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
|
||||||
|
"""Fetch user info and organizations from GitHub API."""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/vnd.github.v3+json",
|
||||||
|
"User-Agent": "llama-stack",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
|
||||||
|
user_response.raise_for_status()
|
||||||
|
user_data = user_response.json()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user": user_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||||
"""Factory function to create the appropriate auth provider."""
|
"""Factory function to create the appropriate auth provider."""
|
||||||
provider_type = config.provider_type.lower()
|
provider_config = config.provider_config
|
||||||
|
|
||||||
if provider_type == "custom":
|
if isinstance(provider_config, CustomAuthConfig):
|
||||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
return CustomAuthProvider(provider_config)
|
||||||
elif provider_type == "oauth2_token":
|
elif isinstance(provider_config, OAuth2TokenAuthConfig):
|
||||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
return OAuth2TokenAuthProvider(provider_config)
|
||||||
|
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||||
|
return GitHubTokenAuthProvider(provider_config)
|
||||||
else:
|
else:
|
||||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
|
||||||
|
|
|
@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthenticationRequiredError,
|
||||||
|
LoggingConfig,
|
||||||
|
StackRunConfig,
|
||||||
|
)
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
|
@ -217,7 +221,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
# Get auth attributes from the request scope
|
# Get auth attributes from the request scope
|
||||||
user_attributes = request.scope.get("user_attributes", {})
|
user_attributes = request.scope.get("user_attributes", {})
|
||||||
principal = request.scope.get("principal", "")
|
principal = request.scope.get("principal", "")
|
||||||
user = User(principal, user_attributes)
|
user = User(principal=principal, attributes=user_attributes)
|
||||||
|
|
||||||
await log_request_pre_validation(request)
|
await log_request_pre_validation(request)
|
||||||
|
|
||||||
|
@ -405,13 +409,13 @@ def main(args: argparse.Namespace | None = None):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
log_line = ""
|
log_line = ""
|
||||||
if args.config:
|
if hasattr(args, "config") and args.config:
|
||||||
# if the user provided a config file, use it, even if template was specified
|
# if the user provided a config file, use it, even if template was specified
|
||||||
config_file = Path(args.config)
|
config_file = Path(args.config)
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise ValueError(f"Config file {config_file} does not exist")
|
raise ValueError(f"Config file {config_file} does not exist")
|
||||||
log_line = f"Using config file: {config_file}"
|
log_line = f"Using config file: {config_file}"
|
||||||
elif args.template:
|
elif hasattr(args, "template") and args.template:
|
||||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise ValueError(f"Template {args.template} does not exist")
|
raise ValueError(f"Template {args.template} does not exist")
|
||||||
|
@ -455,7 +459,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
|
|
||||||
# Add authentication middleware if configured
|
# Add authentication middleware if configured
|
||||||
if config.server.auth:
|
if config.server.auth:
|
||||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||||
else:
|
else:
|
||||||
if config.server.quota:
|
if config.server.quota:
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore.config import (
|
from llama_stack.providers.utils.kvstore.config import (
|
||||||
KVStoreConfig,
|
KVStoreConfig,
|
||||||
|
@ -19,6 +19,7 @@ from llama_stack.schema_utils import json_schema_type
|
||||||
class MilvusVectorIOConfig(BaseModel):
|
class MilvusVectorIOConfig(BaseModel):
|
||||||
db_path: str
|
db_path: str
|
||||||
kvstore: KVStoreConfig
|
kvstore: KVStoreConfig
|
||||||
|
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||||
|
|
|
@ -154,10 +154,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for vector_db_data in stored_vector_dbs:
|
for vector_db_data in stored_vector_dbs:
|
||||||
vector_db = VectorDB.mdel_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
index = VectorDBWithIndex(
|
index = VectorDBWithIndex(
|
||||||
vector_db,
|
vector_db,
|
||||||
index=await MilvusIndex(
|
index=MilvusIndex(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
collection_name=vector_db.identifier,
|
collection_name=vector_db.identifier,
|
||||||
consistency_level=self.config.consistency_level,
|
consistency_level=self.config.consistency_level,
|
||||||
|
@ -259,6 +259,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
assert self.kvstore is not None
|
assert self.kvstore is not None
|
||||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||||
await self.kvstore.delete(key)
|
await self.kvstore.delete(key)
|
||||||
|
if store_id in self.openai_vector_stores:
|
||||||
|
del self.openai_vector_stores[store_id]
|
||||||
|
|
||||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||||
"""Load all vector store metadata from persistent storage."""
|
"""Load all vector store metadata from persistent storage."""
|
||||||
|
@ -377,6 +379,29 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
|
logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||||
|
"""Update vector store file metadata in Milvus database."""
|
||||||
|
try:
|
||||||
|
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||||
|
return
|
||||||
|
|
||||||
|
file_data = [
|
||||||
|
{
|
||||||
|
"store_file_id": f"{store_id}_{file_id}",
|
||||||
|
"store_id": store_id,
|
||||||
|
"file_id": file_id,
|
||||||
|
"file_info": json.dumps(file_info),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.client.upsert,
|
||||||
|
collection_name="openai_vector_store_files",
|
||||||
|
data=file_data,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||||
"""Load vector store file contents from Milvus database."""
|
"""Load vector store file contents from Milvus database."""
|
||||||
try:
|
try:
|
||||||
|
@ -405,29 +430,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
|
logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
|
||||||
"""Update vector store file metadata in Milvus database."""
|
|
||||||
try:
|
|
||||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
|
||||||
return
|
|
||||||
|
|
||||||
file_data = [
|
|
||||||
{
|
|
||||||
"store_file_id": f"{store_id}_{file_id}",
|
|
||||||
"store_id": store_id,
|
|
||||||
"file_id": file_id,
|
|
||||||
"file_info": json.dumps(file_info),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
await asyncio.to_thread(
|
|
||||||
self.client.upsert,
|
|
||||||
collection_name="openai_vector_store_files",
|
|
||||||
data=file_data,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||||
"""Delete vector store file metadata from Milvus database."""
|
"""Delete vector store file metadata from Milvus database."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||||
|
from .sqlstore import SqlStoreType
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||||
|
|
||||||
|
@ -71,9 +72,18 @@ class AuthorizedSqlStore:
|
||||||
:param sql_store: Base SqlStore implementation to wrap
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
"""
|
"""
|
||||||
self.sql_store = sql_store
|
self.sql_store = sql_store
|
||||||
|
self._detect_database_type()
|
||||||
self._validate_sql_optimized_policy()
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
def _detect_database_type(self) -> None:
|
||||||
|
"""Detect the database type from the underlying SQL store."""
|
||||||
|
if not hasattr(self.sql_store, "config"):
|
||||||
|
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||||
|
|
||||||
|
self.database_type = self.sql_store.config.type
|
||||||
|
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
def _validate_sql_optimized_policy(self) -> None:
|
def _validate_sql_optimized_policy(self) -> None:
|
||||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||||
|
|
||||||
|
@ -181,6 +191,50 @@ class AuthorizedSqlStore:
|
||||||
else:
|
else:
|
||||||
return self._build_conservative_where_clause()
|
return self._build_conservative_where_clause()
|
||||||
|
|
||||||
|
def _json_extract(self, column: str, path: str) -> str:
|
||||||
|
"""Extract JSON value (keeping JSON type).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The JSON column name
|
||||||
|
path: The JSON path (e.g., 'roles', 'teams')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQL expression to extract JSON value
|
||||||
|
"""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
return f"{column}->'{path}'"
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
|
def _json_extract_text(self, column: str, path: str) -> str:
|
||||||
|
"""Extract JSON value as text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The JSON column name
|
||||||
|
path: The JSON path (e.g., 'roles', 'teams')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQL expression to extract JSON value as text
|
||||||
|
"""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
return f"{column}->>'{path}'"
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
|
def _get_public_access_conditions(self) -> list[str]:
|
||||||
|
"""Get the SQL conditions for public access."""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
# Postgres stores JSON null as 'null'
|
||||||
|
return ["access_attributes::text = 'null'"]
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return ["access_attributes = 'null'"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
def _build_default_policy_where_clause(self) -> str:
|
def _build_default_policy_where_clause(self) -> str:
|
||||||
"""Build SQL WHERE clause for the default policy.
|
"""Build SQL WHERE clause for the default policy.
|
||||||
|
|
||||||
|
@ -189,29 +243,32 @@ class AuthorizedSqlStore:
|
||||||
"""
|
"""
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
|
base_conditions = self._get_public_access_conditions()
|
||||||
if not current_user or not current_user.attributes:
|
if not current_user or not current_user.attributes:
|
||||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
# Only allow public records
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
else:
|
else:
|
||||||
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
|
|
||||||
|
|
||||||
user_attr_conditions = []
|
user_attr_conditions = []
|
||||||
|
|
||||||
for attr_key, user_values in current_user.attributes.items():
|
for attr_key, user_values in current_user.attributes.items():
|
||||||
if user_values:
|
if user_values:
|
||||||
value_conditions = []
|
value_conditions = []
|
||||||
for value in user_values:
|
for value in user_values:
|
||||||
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
|
# Check if JSON array contains the value
|
||||||
|
escaped_value = value.replace("'", "''")
|
||||||
|
json_text = self._json_extract_text("access_attributes", attr_key)
|
||||||
|
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
|
||||||
|
|
||||||
if value_conditions:
|
if value_conditions:
|
||||||
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
|
# Check if the category is missing (NULL)
|
||||||
|
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
|
||||||
user_matches_category = f"({' OR '.join(value_conditions)})"
|
user_matches_category = f"({' OR '.join(value_conditions)})"
|
||||||
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
||||||
|
|
||||||
if user_attr_conditions:
|
if user_attr_conditions:
|
||||||
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
||||||
base_conditions.append(all_requirements_met)
|
base_conditions.append(all_requirements_met)
|
||||||
return f"({' OR '.join(base_conditions)})"
|
|
||||||
else:
|
|
||||||
return f"({' OR '.join(base_conditions)})"
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
|
||||||
def _build_conservative_where_clause(self) -> str:
|
def _build_conservative_where_clause(self) -> str:
|
||||||
|
@ -222,5 +279,8 @@ class AuthorizedSqlStore:
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
if not current_user:
|
if not current_user:
|
||||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
# Only allow public records
|
||||||
|
base_conditions = self._get_public_access_conditions()
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
|
||||||
return "1=1"
|
return "1=1"
|
||||||
|
|
|
@ -4,9 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
@ -19,7 +18,7 @@ from .api import SqlStore
|
||||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||||
|
|
||||||
|
|
||||||
class SqlStoreType(Enum):
|
class SqlStoreType(StrEnum):
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
postgres = "postgres"
|
postgres = "postgres"
|
||||||
|
|
||||||
|
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
|
||||||
db_path: str = Field(
|
db_path: str = Field(
|
||||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||||
|
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
|
|
||||||
|
|
||||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
|
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||||
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
|
||||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
|
||||||
impl = SqlAlchemySqlStoreImpl(config)
|
impl = SqlAlchemySqlStoreImpl(config)
|
||||||
|
|
7
llama_stack/templates/nvidia/__init__.py
Normal file
7
llama_stack/templates/nvidia/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .nvidia import get_distribution_template # noqa: F401
|
29
llama_stack/templates/nvidia/build.yaml
Normal file
29
llama_stack/templates/nvidia/build.yaml
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
version: 2
|
||||||
|
distribution_spec:
|
||||||
|
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::nvidia
|
||||||
|
vector_io:
|
||||||
|
- inline::faiss
|
||||||
|
safety:
|
||||||
|
- remote::nvidia
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
eval:
|
||||||
|
- remote::nvidia
|
||||||
|
post_training:
|
||||||
|
- remote::nvidia
|
||||||
|
datasetio:
|
||||||
|
- inline::localfs
|
||||||
|
- remote::nvidia
|
||||||
|
scoring:
|
||||||
|
- inline::basic
|
||||||
|
tool_runtime:
|
||||||
|
- inline::rag-runtime
|
||||||
|
image_type: conda
|
||||||
|
additional_pip_packages:
|
||||||
|
- aiosqlite
|
||||||
|
- sqlalchemy[asyncio]
|
149
llama_stack/templates/nvidia/doc_template.md
Normal file
149
llama_stack/templates/nvidia/doc_template.md
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
# NVIDIA Distribution
|
||||||
|
|
||||||
|
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
{{ providers_table }}
|
||||||
|
|
||||||
|
{% if run_config_env_vars %}
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The following environment variables can be configured:
|
||||||
|
|
||||||
|
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||||
|
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if default_models %}
|
||||||
|
### Models
|
||||||
|
|
||||||
|
The following models are available by default:
|
||||||
|
|
||||||
|
{% for model in default_models %}
|
||||||
|
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
### NVIDIA API Keys
|
||||||
|
|
||||||
|
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||||
|
|
||||||
|
### Deploy NeMo Microservices Platform
|
||||||
|
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||||
|
|
||||||
|
## Supported Services
|
||||||
|
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||||
|
|
||||||
|
### Inference: NVIDIA NIM
|
||||||
|
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||||
|
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||||
|
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||||
|
|
||||||
|
The deployed platform includes the NIM Proxy microservice, which is the service that provides to access your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
|
||||||
|
|
||||||
|
### Datasetio API: NeMo Data Store
|
||||||
|
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Eval API: NeMo Evaluator
|
||||||
|
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Post-Training API: NeMo Customizer
|
||||||
|
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
### Safety API: NeMo Guardrails
|
||||||
|
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
|
|
||||||
|
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||||
|
|
||||||
|
## Deploying models
|
||||||
|
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||||
|
|
||||||
|
Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||||
|
```sh
|
||||||
|
# URL to NeMo NIM Proxy service
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"name": "llama-3.2-1b-instruct",
|
||||||
|
"namespace": "meta",
|
||||||
|
"config": {
|
||||||
|
"model": "meta/llama-3.2-1b-instruct",
|
||||||
|
"nim_deployment": {
|
||||||
|
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||||
|
"image_tag": "1.8.3",
|
||||||
|
"pvc_size": "25Gi",
|
||||||
|
"gpu": 1,
|
||||||
|
"additional_envs": {
|
||||||
|
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||||
|
|
||||||
|
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||||
|
```sh
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Llama Stack with NVIDIA
|
||||||
|
|
||||||
|
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||||
|
|
||||||
|
### Via Docker
|
||||||
|
|
||||||
|
This method allows you to get started quickly without having to build the distribution code.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLAMA_STACK_PORT=8321
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
--pull always \
|
||||||
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
llamastack/distribution-{{ name }} \
|
||||||
|
--config /root/my-run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via Conda
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
|
llama stack build --template nvidia --image-type conda
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port 8321 \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via venv
|
||||||
|
|
||||||
|
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
|
llama stack build --template nvidia --image-type venv
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port 8321 \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example Notebooks
|
||||||
|
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
150
llama_stack/templates/nvidia/nvidia.py
Normal file
150
llama_stack/templates/nvidia/nvidia.py
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||||
|
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
|
||||||
|
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||||
|
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||||
|
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||||
|
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||||
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
|
providers = {
|
||||||
|
"inference": ["remote::nvidia"],
|
||||||
|
"vector_io": ["inline::faiss"],
|
||||||
|
"safety": ["remote::nvidia"],
|
||||||
|
"agents": ["inline::meta-reference"],
|
||||||
|
"telemetry": ["inline::meta-reference"],
|
||||||
|
"eval": ["remote::nvidia"],
|
||||||
|
"post_training": ["remote::nvidia"],
|
||||||
|
"datasetio": ["inline::localfs", "remote::nvidia"],
|
||||||
|
"scoring": ["inline::basic"],
|
||||||
|
"tool_runtime": ["inline::rag-runtime"],
|
||||||
|
}
|
||||||
|
|
||||||
|
inference_provider = Provider(
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_type="remote::nvidia",
|
||||||
|
config=NVIDIAConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
safety_provider = Provider(
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_type="remote::nvidia",
|
||||||
|
config=NVIDIASafetyConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
datasetio_provider = Provider(
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_type="remote::nvidia",
|
||||||
|
config=NvidiaDatasetIOConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
eval_provider = Provider(
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_type="remote::nvidia",
|
||||||
|
config=NVIDIAEvalConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
inference_model = ModelInput(
|
||||||
|
model_id="${env.INFERENCE_MODEL}",
|
||||||
|
provider_id="nvidia",
|
||||||
|
)
|
||||||
|
safety_model = ModelInput(
|
||||||
|
model_id="${env.SAFETY_MODEL}",
|
||||||
|
provider_id="nvidia",
|
||||||
|
)
|
||||||
|
|
||||||
|
available_models = {
|
||||||
|
"nvidia": MODEL_ENTRIES,
|
||||||
|
}
|
||||||
|
default_tool_groups = [
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::rag",
|
||||||
|
provider_id="rag-runtime",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
default_models = get_model_registry(available_models)
|
||||||
|
return DistributionTemplate(
|
||||||
|
name="nvidia",
|
||||||
|
distro_type="self_hosted",
|
||||||
|
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
|
||||||
|
container_image=None,
|
||||||
|
template_path=Path(__file__).parent / "doc_template.md",
|
||||||
|
providers=providers,
|
||||||
|
available_models_by_provider=available_models,
|
||||||
|
run_configs={
|
||||||
|
"run.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": [inference_provider],
|
||||||
|
"datasetio": [datasetio_provider],
|
||||||
|
"eval": [eval_provider],
|
||||||
|
},
|
||||||
|
default_models=default_models,
|
||||||
|
default_tool_groups=default_tool_groups,
|
||||||
|
),
|
||||||
|
"run-with-safety.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": [
|
||||||
|
inference_provider,
|
||||||
|
safety_provider,
|
||||||
|
],
|
||||||
|
"eval": [eval_provider],
|
||||||
|
},
|
||||||
|
default_models=[inference_model, safety_model],
|
||||||
|
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||||
|
default_tool_groups=default_tool_groups,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
run_config_env_vars={
|
||||||
|
"NVIDIA_API_KEY": (
|
||||||
|
"",
|
||||||
|
"NVIDIA API Key",
|
||||||
|
),
|
||||||
|
"NVIDIA_APPEND_API_VERSION": (
|
||||||
|
"True",
|
||||||
|
"Whether to append the API version to the base_url",
|
||||||
|
),
|
||||||
|
## Nemo Customizer related variables
|
||||||
|
"NVIDIA_DATASET_NAMESPACE": (
|
||||||
|
"default",
|
||||||
|
"NVIDIA Dataset Namespace",
|
||||||
|
),
|
||||||
|
"NVIDIA_PROJECT_ID": (
|
||||||
|
"test-project",
|
||||||
|
"NVIDIA Project ID",
|
||||||
|
),
|
||||||
|
"NVIDIA_CUSTOMIZER_URL": (
|
||||||
|
"https://customizer.api.nvidia.com",
|
||||||
|
"NVIDIA Customizer URL",
|
||||||
|
),
|
||||||
|
"NVIDIA_OUTPUT_MODEL_DIR": (
|
||||||
|
"test-example-model@v1",
|
||||||
|
"NVIDIA Output Model Directory",
|
||||||
|
),
|
||||||
|
"GUARDRAILS_SERVICE_URL": (
|
||||||
|
"http://0.0.0.0:7331",
|
||||||
|
"URL for the NeMo Guardrails Service",
|
||||||
|
),
|
||||||
|
"NVIDIA_GUARDRAILS_CONFIG_ID": (
|
||||||
|
"self-check",
|
||||||
|
"NVIDIA Guardrail Configuration ID",
|
||||||
|
),
|
||||||
|
"NVIDIA_EVALUATOR_URL": (
|
||||||
|
"http://0.0.0.0:7331",
|
||||||
|
"URL for the NeMo Evaluator Service",
|
||||||
|
),
|
||||||
|
"INFERENCE_MODEL": (
|
||||||
|
"Llama3.1-8B-Instruct",
|
||||||
|
"Inference model",
|
||||||
|
),
|
||||||
|
"SAFETY_MODEL": (
|
||||||
|
"meta/llama-3.1-8b-instruct",
|
||||||
|
"Name of the model to use for safety",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
119
llama_stack/templates/nvidia/run-with-safety.yaml
Normal file
119
llama_stack/templates/nvidia/run-with-safety.yaml
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
version: 2
|
||||||
|
image_name: nvidia
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- datasetio
|
||||||
|
- eval
|
||||||
|
- inference
|
||||||
|
- post_training
|
||||||
|
- safety
|
||||||
|
- scoring
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||||
|
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||||
|
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
|
||||||
|
responses_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||||
|
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||||
|
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||||
|
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||||
|
eval:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
|
||||||
|
post_training:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||||
|
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||||
|
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
|
||||||
|
datasetio:
|
||||||
|
- provider_id: localfs
|
||||||
|
provider_type: inline::localfs
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/localfs_datasetio.db
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||||
|
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||||
|
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
|
||||||
|
scoring:
|
||||||
|
- provider_id: basic
|
||||||
|
provider_type: inline::basic
|
||||||
|
config: {}
|
||||||
|
tool_runtime:
|
||||||
|
- provider_id: rag-runtime
|
||||||
|
provider_type: inline::rag-runtime
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||||
|
inference_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: ${env.INFERENCE_MODEL}
|
||||||
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: ${env.SAFETY_MODEL}
|
||||||
|
provider_id: nvidia
|
||||||
|
model_type: llm
|
||||||
|
shields:
|
||||||
|
- shield_id: ${env.SAFETY_MODEL}
|
||||||
|
provider_id: nvidia
|
||||||
|
vector_dbs: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
benchmarks: []
|
||||||
|
tool_groups:
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: rag-runtime
|
||||||
|
server:
|
||||||
|
port: 8321
|
226
llama_stack/templates/nvidia/run.yaml
Normal file
226
llama_stack/templates/nvidia/run.yaml
Normal file
|
@ -0,0 +1,226 @@
|
||||||
|
version: 2
|
||||||
|
image_name: nvidia
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- datasetio
|
||||||
|
- eval
|
||||||
|
- inference
|
||||||
|
- post_training
|
||||||
|
- safety
|
||||||
|
- scoring
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||||
|
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
|
||||||
|
responses_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||||
|
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||||
|
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||||
|
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||||
|
eval:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
|
||||||
|
post_training:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||||
|
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||||
|
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
|
||||||
|
datasetio:
|
||||||
|
- provider_id: nvidia
|
||||||
|
provider_type: remote::nvidia
|
||||||
|
config:
|
||||||
|
api_key: ${env.NVIDIA_API_KEY:=}
|
||||||
|
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||||
|
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||||
|
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
|
||||||
|
scoring:
|
||||||
|
- provider_id: basic
|
||||||
|
provider_type: inline::basic
|
||||||
|
config: {}
|
||||||
|
tool_runtime:
|
||||||
|
- provider_id: rag-runtime
|
||||||
|
provider_type: inline::rag-runtime
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||||
|
inference_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama3-8b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama3-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3-8B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama3-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama3-70b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3-70B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.1-8b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.1-70b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.1-405b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-405b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.1-405b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.2-1b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.2-3b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.3-70b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 2048
|
||||||
|
context_length: 8192
|
||||||
|
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||||
|
model_type: embedding
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 1024
|
||||||
|
context_length: 512
|
||||||
|
model_id: nvidia/nv-embedqa-e5-v5
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||||
|
model_type: embedding
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 4096
|
||||||
|
context_length: 512
|
||||||
|
model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||||
|
model_type: embedding
|
||||||
|
- metadata:
|
||||||
|
embedding_dimension: 1024
|
||||||
|
context_length: 512
|
||||||
|
model_id: snowflake/arctic-embed-l
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: snowflake/arctic-embed-l
|
||||||
|
model_type: embedding
|
||||||
|
shields: []
|
||||||
|
vector_dbs: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
benchmarks: []
|
||||||
|
tool_groups:
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: rag-runtime
|
||||||
|
server:
|
||||||
|
port: 8321
|
6
llama_stack/ui/app/api/auth/[...nextauth]/route.ts
Normal file
6
llama_stack/ui/app/api/auth/[...nextauth]/route.ts
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
import NextAuth from "next-auth";
|
||||||
|
import { authOptions } from "@/lib/auth";
|
||||||
|
|
||||||
|
const handler = NextAuth(authOptions);
|
||||||
|
|
||||||
|
export { handler as GET, handler as POST };
|
118
llama_stack/ui/app/auth/signin/page.tsx
Normal file
118
llama_stack/ui/app/auth/signin/page.tsx
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { signIn, signOut, useSession } from "next-auth/react";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
CardDescription,
|
||||||
|
CardHeader,
|
||||||
|
CardTitle,
|
||||||
|
} from "@/components/ui/card";
|
||||||
|
import { Copy, Check, Home, Github } from "lucide-react";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
|
||||||
|
export default function SignInPage() {
|
||||||
|
const { data: session, status } = useSession();
|
||||||
|
const [copied, setCopied] = useState(false);
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
const handleCopyToken = async () => {
|
||||||
|
if (session?.accessToken) {
|
||||||
|
await navigator.clipboard.writeText(session.accessToken);
|
||||||
|
setCopied(true);
|
||||||
|
setTimeout(() => setCopied(false), 2000);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (status === "loading") {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center min-h-screen">
|
||||||
|
<div className="text-muted-foreground">Loading...</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center justify-center min-h-screen">
|
||||||
|
<Card className="w-[400px]">
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>Authentication</CardTitle>
|
||||||
|
<CardDescription>
|
||||||
|
{session
|
||||||
|
? "You are successfully authenticated!"
|
||||||
|
: "Sign in with GitHub to use your access token as an API key"}
|
||||||
|
</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-4">
|
||||||
|
{!session ? (
|
||||||
|
<Button
|
||||||
|
onClick={() => {
|
||||||
|
console.log("Signing in with GitHub...");
|
||||||
|
signIn("github", { callbackUrl: "/auth/signin" }).catch(
|
||||||
|
(error) => {
|
||||||
|
console.error("Sign in error:", error);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}}
|
||||||
|
className="w-full"
|
||||||
|
variant="default"
|
||||||
|
>
|
||||||
|
<Github className="mr-2 h-4 w-4" />
|
||||||
|
Sign in with GitHub
|
||||||
|
</Button>
|
||||||
|
) : (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
Signed in as {session.user?.email}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{session.accessToken && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="text-sm font-medium">
|
||||||
|
GitHub Access Token:
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<code className="flex-1 p-2 bg-muted rounded text-xs break-all">
|
||||||
|
{session.accessToken}
|
||||||
|
</code>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
onClick={handleCopyToken}
|
||||||
|
>
|
||||||
|
{copied ? (
|
||||||
|
<Check className="h-4 w-4" />
|
||||||
|
) : (
|
||||||
|
<Copy className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-muted-foreground">
|
||||||
|
This GitHub token will be used as your API key for
|
||||||
|
authenticated Llama Stack requests.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Button onClick={() => router.push("/")} className="flex-1">
|
||||||
|
<Home className="mr-2 h-4 w-4" />
|
||||||
|
Go to Dashboard
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
onClick={() => signOut()}
|
||||||
|
variant="outline"
|
||||||
|
className="flex-1"
|
||||||
|
>
|
||||||
|
Sign out
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
import type { Metadata } from "next";
|
import type { Metadata } from "next";
|
||||||
import { ThemeProvider } from "@/components/ui/theme-provider";
|
import { ThemeProvider } from "@/components/ui/theme-provider";
|
||||||
|
import { SessionProvider } from "@/components/providers/session-provider";
|
||||||
import { Geist, Geist_Mono } from "next/font/google";
|
import { Geist, Geist_Mono } from "next/font/google";
|
||||||
import { ModeToggle } from "@/components/ui/mode-toggle";
|
import { ModeToggle } from "@/components/ui/mode-toggle";
|
||||||
import "./globals.css";
|
import "./globals.css";
|
||||||
|
@ -21,11 +22,13 @@ export const metadata: Metadata = {
|
||||||
|
|
||||||
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||||
import { AppSidebar } from "@/components/layout/app-sidebar";
|
import { AppSidebar } from "@/components/layout/app-sidebar";
|
||||||
|
import { SignInButton } from "@/components/ui/sign-in-button";
|
||||||
|
|
||||||
export default function Layout({ children }: { children: React.ReactNode }) {
|
export default function Layout({ children }: { children: React.ReactNode }) {
|
||||||
return (
|
return (
|
||||||
<html lang="en" suppressHydrationWarning>
|
<html lang="en" suppressHydrationWarning>
|
||||||
<body className={`${geistSans.variable} ${geistMono.variable} font-sans`}>
|
<body className={`${geistSans.variable} ${geistMono.variable} font-sans`}>
|
||||||
|
<SessionProvider>
|
||||||
<ThemeProvider
|
<ThemeProvider
|
||||||
attribute="class"
|
attribute="class"
|
||||||
defaultTheme="system"
|
defaultTheme="system"
|
||||||
|
@ -41,7 +44,8 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||||
<SidebarTrigger />
|
<SidebarTrigger />
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-1 text-center"></div>
|
<div className="flex-1 text-center"></div>
|
||||||
<div className="flex-none">
|
<div className="flex-none flex items-center gap-2">
|
||||||
|
<SignInButton />
|
||||||
<ModeToggle />
|
<ModeToggle />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -49,6 +53,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||||
</main>
|
</main>
|
||||||
</SidebarProvider>
|
</SidebarProvider>
|
||||||
</ThemeProvider>
|
</ThemeProvider>
|
||||||
|
</SessionProvider>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
);
|
);
|
||||||
|
|
|
@ -4,11 +4,12 @@ import { useEffect, useState } from "react";
|
||||||
import { useParams } from "next/navigation";
|
import { useParams } from "next/navigation";
|
||||||
import { ChatCompletion } from "@/lib/types";
|
import { ChatCompletion } from "@/lib/types";
|
||||||
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
|
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
|
||||||
import { client } from "@/lib/client";
|
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||||
|
|
||||||
export default function ChatCompletionDetailPage() {
|
export default function ChatCompletionDetailPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const id = params.id as string;
|
const id = params.id as string;
|
||||||
|
const client = useAuthClient();
|
||||||
|
|
||||||
const [completionDetail, setCompletionDetail] =
|
const [completionDetail, setCompletionDetail] =
|
||||||
useState<ChatCompletion | null>(null);
|
useState<ChatCompletion | null>(null);
|
||||||
|
@ -45,7 +46,7 @@ export default function ChatCompletionDetailPage() {
|
||||||
};
|
};
|
||||||
|
|
||||||
fetchCompletionDetail();
|
fetchCompletionDetail();
|
||||||
}, [id]);
|
}, [id, client]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ChatCompletionDetailView
|
<ChatCompletionDetailView
|
||||||
|
|
|
@ -5,11 +5,12 @@ import { useParams } from "next/navigation";
|
||||||
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
|
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
|
||||||
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
||||||
import { ResponseDetailView } from "@/components/responses/responses-detail";
|
import { ResponseDetailView } from "@/components/responses/responses-detail";
|
||||||
import { client } from "@/lib/client";
|
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||||
|
|
||||||
export default function ResponseDetailPage() {
|
export default function ResponseDetailPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const id = params.id as string;
|
const id = params.id as string;
|
||||||
|
const client = useAuthClient();
|
||||||
|
|
||||||
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
|
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
|
||||||
null,
|
null,
|
||||||
|
@ -109,7 +110,7 @@ export default function ResponseDetailPage() {
|
||||||
};
|
};
|
||||||
|
|
||||||
fetchResponseDetail();
|
fetchResponseDetail();
|
||||||
}, [id]);
|
}, [id, client]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ResponseDetailView
|
<ResponseDetailView
|
||||||
|
|
|
@ -12,24 +12,34 @@ jest.mock("next/navigation", () => ({
|
||||||
}),
|
}),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Mock next-auth
|
||||||
|
jest.mock("next-auth/react", () => ({
|
||||||
|
useSession: () => ({
|
||||||
|
status: "authenticated",
|
||||||
|
data: { accessToken: "mock-token" },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
// Mock helper functions
|
// Mock helper functions
|
||||||
jest.mock("@/lib/truncate-text");
|
jest.mock("@/lib/truncate-text");
|
||||||
jest.mock("@/lib/format-message-content");
|
jest.mock("@/lib/format-message-content");
|
||||||
|
|
||||||
// Mock the client
|
// Mock the auth client hook
|
||||||
jest.mock("@/lib/client", () => ({
|
const mockClient = {
|
||||||
client: {
|
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
list: jest.fn(),
|
list: jest.fn(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
|
|
||||||
|
jest.mock("@/hooks/use-auth-client", () => ({
|
||||||
|
useAuthClient: () => mockClient,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock the usePagination hook
|
// Mock the usePagination hook
|
||||||
const mockLoadMore = jest.fn();
|
const mockLoadMore = jest.fn();
|
||||||
jest.mock("@/hooks/usePagination", () => ({
|
jest.mock("@/hooks/use-pagination", () => ({
|
||||||
usePagination: jest.fn(() => ({
|
usePagination: jest.fn(() => ({
|
||||||
data: [],
|
data: [],
|
||||||
status: "idle",
|
status: "idle",
|
||||||
|
@ -47,7 +57,7 @@ import {
|
||||||
} from "@/lib/format-message-content";
|
} from "@/lib/format-message-content";
|
||||||
|
|
||||||
// Import the mocked hook
|
// Import the mocked hook
|
||||||
import { usePagination } from "@/hooks/usePagination";
|
import { usePagination } from "@/hooks/use-pagination";
|
||||||
const mockedUsePagination = usePagination as jest.MockedFunction<
|
const mockedUsePagination = usePagination as jest.MockedFunction<
|
||||||
typeof usePagination
|
typeof usePagination
|
||||||
>;
|
>;
|
||||||
|
|
|
@ -10,8 +10,7 @@ import {
|
||||||
extractTextFromContentPart,
|
extractTextFromContentPart,
|
||||||
extractDisplayableText,
|
extractDisplayableText,
|
||||||
} from "@/lib/format-message-content";
|
} from "@/lib/format-message-content";
|
||||||
import { usePagination } from "@/hooks/usePagination";
|
import { usePagination } from "@/hooks/use-pagination";
|
||||||
import { client } from "@/lib/client";
|
|
||||||
|
|
||||||
interface ChatCompletionsTableProps {
|
interface ChatCompletionsTableProps {
|
||||||
/** Optional pagination configuration */
|
/** Optional pagination configuration */
|
||||||
|
@ -32,12 +31,15 @@ function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
|
||||||
export function ChatCompletionsTable({
|
export function ChatCompletionsTable({
|
||||||
paginationOptions,
|
paginationOptions,
|
||||||
}: ChatCompletionsTableProps) {
|
}: ChatCompletionsTableProps) {
|
||||||
const fetchFunction = async (params: {
|
const fetchFunction = async (
|
||||||
|
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
|
||||||
|
params: {
|
||||||
after?: string;
|
after?: string;
|
||||||
limit: number;
|
limit: number;
|
||||||
model?: string;
|
model?: string;
|
||||||
order?: string;
|
order?: string;
|
||||||
}) => {
|
},
|
||||||
|
) => {
|
||||||
const response = await client.chat.completions.list({
|
const response = await client.chat.completions.list({
|
||||||
after: params.after,
|
after: params.after,
|
||||||
limit: params.limit,
|
limit: params.limit,
|
||||||
|
|
|
@ -12,7 +12,7 @@ jest.mock("next/navigation", () => ({
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock the useInfiniteScroll hook
|
// Mock the useInfiniteScroll hook
|
||||||
jest.mock("@/hooks/useInfiniteScroll", () => ({
|
jest.mock("@/hooks/use-infinite-scroll", () => ({
|
||||||
useInfiniteScroll: jest.fn((onLoadMore, options) => {
|
useInfiniteScroll: jest.fn((onLoadMore, options) => {
|
||||||
const ref = React.useRef(null);
|
const ref = React.useRef(null);
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import { useRouter } from "next/navigation";
|
||||||
import { useRef } from "react";
|
import { useRef } from "react";
|
||||||
import { truncateText } from "@/lib/truncate-text";
|
import { truncateText } from "@/lib/truncate-text";
|
||||||
import { PaginationStatus } from "@/lib/types";
|
import { PaginationStatus } from "@/lib/types";
|
||||||
import { useInfiniteScroll } from "@/hooks/useInfiniteScroll";
|
import { useInfiniteScroll } from "@/hooks/use-infinite-scroll";
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
|
|
7
llama_stack/ui/components/providers/session-provider.tsx
Normal file
7
llama_stack/ui/components/providers/session-provider.tsx
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { SessionProvider as NextAuthSessionProvider } from "next-auth/react";
|
||||||
|
|
||||||
|
export function SessionProvider({ children }: { children: React.ReactNode }) {
|
||||||
|
return <NextAuthSessionProvider>{children}</NextAuthSessionProvider>;
|
||||||
|
}
|
|
@ -12,21 +12,31 @@ jest.mock("next/navigation", () => ({
|
||||||
}),
|
}),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Mock next-auth
|
||||||
|
jest.mock("next-auth/react", () => ({
|
||||||
|
useSession: () => ({
|
||||||
|
status: "authenticated",
|
||||||
|
data: { accessToken: "mock-token" },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
// Mock helper functions
|
// Mock helper functions
|
||||||
jest.mock("@/lib/truncate-text");
|
jest.mock("@/lib/truncate-text");
|
||||||
|
|
||||||
// Mock the client
|
// Mock the auth client hook
|
||||||
jest.mock("@/lib/client", () => ({
|
const mockClient = {
|
||||||
client: {
|
|
||||||
responses: {
|
responses: {
|
||||||
list: jest.fn(),
|
list: jest.fn(),
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
|
|
||||||
|
jest.mock("@/hooks/use-auth-client", () => ({
|
||||||
|
useAuthClient: () => mockClient,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock the usePagination hook
|
// Mock the usePagination hook
|
||||||
const mockLoadMore = jest.fn();
|
const mockLoadMore = jest.fn();
|
||||||
jest.mock("@/hooks/usePagination", () => ({
|
jest.mock("@/hooks/use-pagination", () => ({
|
||||||
usePagination: jest.fn(() => ({
|
usePagination: jest.fn(() => ({
|
||||||
data: [],
|
data: [],
|
||||||
status: "idle",
|
status: "idle",
|
||||||
|
@ -40,7 +50,7 @@ jest.mock("@/hooks/usePagination", () => ({
|
||||||
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
||||||
|
|
||||||
// Import the mocked hook
|
// Import the mocked hook
|
||||||
import { usePagination } from "@/hooks/usePagination";
|
import { usePagination } from "@/hooks/use-pagination";
|
||||||
const mockedUsePagination = usePagination as jest.MockedFunction<
|
const mockedUsePagination = usePagination as jest.MockedFunction<
|
||||||
typeof usePagination
|
typeof usePagination
|
||||||
>;
|
>;
|
||||||
|
|
|
@ -6,8 +6,7 @@ import {
|
||||||
UsePaginationOptions,
|
UsePaginationOptions,
|
||||||
} from "@/lib/types";
|
} from "@/lib/types";
|
||||||
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||||
import { usePagination } from "@/hooks/usePagination";
|
import { usePagination } from "@/hooks/use-pagination";
|
||||||
import { client } from "@/lib/client";
|
|
||||||
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
|
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
|
||||||
import {
|
import {
|
||||||
isMessageInput,
|
isMessageInput,
|
||||||
|
@ -125,12 +124,15 @@ function formatResponseToRow(response: OpenAIResponse): LogTableRow {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
|
export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
|
||||||
const fetchFunction = async (params: {
|
const fetchFunction = async (
|
||||||
|
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
|
||||||
|
params: {
|
||||||
after?: string;
|
after?: string;
|
||||||
limit: number;
|
limit: number;
|
||||||
model?: string;
|
model?: string;
|
||||||
order?: string;
|
order?: string;
|
||||||
}) => {
|
},
|
||||||
|
) => {
|
||||||
const response = await client.responses.list({
|
const response = await client.responses.list({
|
||||||
after: params.after,
|
after: params.after,
|
||||||
limit: params.limit,
|
limit: params.limit,
|
||||||
|
|
25
llama_stack/ui/components/ui/sign-in-button.tsx
Normal file
25
llama_stack/ui/components/ui/sign-in-button.tsx
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { User } from "lucide-react";
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useSession } from "next-auth/react";
|
||||||
|
import { Button } from "./button";
|
||||||
|
|
||||||
|
export function SignInButton() {
|
||||||
|
const { data: session, status } = useSession();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Button variant="ghost" size="sm" asChild>
|
||||||
|
<Link href="/auth/signin" className="flex items-center">
|
||||||
|
<User className="mr-2 h-4 w-4" />
|
||||||
|
<span>
|
||||||
|
{status === "loading"
|
||||||
|
? "Loading..."
|
||||||
|
: session
|
||||||
|
? session.user?.email || "Signed In"
|
||||||
|
: "Sign In"}
|
||||||
|
</span>
|
||||||
|
</Link>
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
24
llama_stack/ui/hooks/use-auth-client.ts
Normal file
24
llama_stack/ui/hooks/use-auth-client.ts
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import { useSession } from "next-auth/react";
|
||||||
|
import { useMemo } from "react";
|
||||||
|
import LlamaStackClient from "llama-stack-client";
|
||||||
|
|
||||||
|
export function useAuthClient() {
|
||||||
|
const { data: session } = useSession();
|
||||||
|
|
||||||
|
const client = useMemo(() => {
|
||||||
|
const clientHostname =
|
||||||
|
typeof window !== "undefined" ? window.location.origin : "";
|
||||||
|
|
||||||
|
const options: any = {
|
||||||
|
baseURL: `${clientHostname}/api`,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (session?.accessToken) {
|
||||||
|
options.apiKey = session.accessToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new LlamaStackClient(options);
|
||||||
|
}, [session?.accessToken]);
|
||||||
|
|
||||||
|
return client;
|
||||||
|
}
|
|
@ -2,6 +2,9 @@
|
||||||
|
|
||||||
import { useState, useCallback, useEffect, useRef } from "react";
|
import { useState, useCallback, useEffect, useRef } from "react";
|
||||||
import { PaginationStatus, UsePaginationOptions } from "@/lib/types";
|
import { PaginationStatus, UsePaginationOptions } from "@/lib/types";
|
||||||
|
import { useSession } from "next-auth/react";
|
||||||
|
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
|
||||||
interface PaginationState<T> {
|
interface PaginationState<T> {
|
||||||
data: T[];
|
data: T[];
|
||||||
|
@ -28,13 +31,18 @@ export interface PaginationReturn<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
interface UsePaginationParams<T> extends UsePaginationOptions {
|
interface UsePaginationParams<T> extends UsePaginationOptions {
|
||||||
fetchFunction: (params: {
|
fetchFunction: (
|
||||||
|
client: ReturnType<typeof useAuthClient>,
|
||||||
|
params: {
|
||||||
after?: string;
|
after?: string;
|
||||||
limit: number;
|
limit: number;
|
||||||
model?: string;
|
model?: string;
|
||||||
order?: string;
|
order?: string;
|
||||||
}) => Promise<PaginationResponse<T>>;
|
},
|
||||||
|
) => Promise<PaginationResponse<T>>;
|
||||||
errorMessagePrefix: string;
|
errorMessagePrefix: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
useAuth?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function usePagination<T>({
|
export function usePagination<T>({
|
||||||
|
@ -43,7 +51,12 @@ export function usePagination<T>({
|
||||||
order = "desc",
|
order = "desc",
|
||||||
fetchFunction,
|
fetchFunction,
|
||||||
errorMessagePrefix,
|
errorMessagePrefix,
|
||||||
|
enabled = true,
|
||||||
|
useAuth = true,
|
||||||
}: UsePaginationParams<T>): PaginationReturn<T> {
|
}: UsePaginationParams<T>): PaginationReturn<T> {
|
||||||
|
const { status: sessionStatus } = useSession();
|
||||||
|
const client = useAuthClient();
|
||||||
|
const router = useRouter();
|
||||||
const [state, setState] = useState<PaginationState<T>>({
|
const [state, setState] = useState<PaginationState<T>>({
|
||||||
data: [],
|
data: [],
|
||||||
status: "loading",
|
status: "loading",
|
||||||
|
@ -74,7 +87,7 @@ export function usePagination<T>({
|
||||||
error: null,
|
error: null,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const response = await fetchFunction({
|
const response = await fetchFunction(client, {
|
||||||
after: after || undefined,
|
after: after || undefined,
|
||||||
limit: fetchLimit,
|
limit: fetchLimit,
|
||||||
...(model && { model }),
|
...(model && { model }),
|
||||||
|
@ -91,6 +104,17 @@ export function usePagination<T>({
|
||||||
status: "idle",
|
status: "idle",
|
||||||
}));
|
}));
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
// Check if it's a 401 unauthorized error
|
||||||
|
if (
|
||||||
|
err &&
|
||||||
|
typeof err === "object" &&
|
||||||
|
"status" in err &&
|
||||||
|
err.status === 401
|
||||||
|
) {
|
||||||
|
router.push("/auth/signin");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const errorMessage = isInitialLoad
|
const errorMessage = isInitialLoad
|
||||||
? `Failed to load ${errorMessagePrefix}. Please try refreshing the page.`
|
? `Failed to load ${errorMessagePrefix}. Please try refreshing the page.`
|
||||||
: `Failed to load more ${errorMessagePrefix}. Please try again.`;
|
: `Failed to load more ${errorMessagePrefix}. Please try again.`;
|
||||||
|
@ -107,7 +131,7 @@ export function usePagination<T>({
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[limit, model, order, fetchFunction, errorMessagePrefix],
|
[limit, model, order, fetchFunction, errorMessagePrefix, client, router],
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -120,17 +144,28 @@ export function usePagination<T>({
|
||||||
}
|
}
|
||||||
}, [fetchData]);
|
}, [fetchData]);
|
||||||
|
|
||||||
// Auto-load initial data on mount
|
// Auto-load initial data on mount when enabled
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!hasFetchedInitialData.current) {
|
// If using auth, wait for session to load
|
||||||
|
const isAuthReady = !useAuth || sessionStatus !== "loading";
|
||||||
|
const shouldFetch = enabled && isAuthReady;
|
||||||
|
|
||||||
|
if (shouldFetch && !hasFetchedInitialData.current) {
|
||||||
hasFetchedInitialData.current = true;
|
hasFetchedInitialData.current = true;
|
||||||
fetchData();
|
fetchData();
|
||||||
|
} else if (!shouldFetch) {
|
||||||
|
// Reset the flag when disabled so it can fetch when re-enabled
|
||||||
|
hasFetchedInitialData.current = false;
|
||||||
}
|
}
|
||||||
}, [fetchData]);
|
}, [fetchData, enabled, useAuth, sessionStatus]);
|
||||||
|
|
||||||
|
// Override status if we're waiting for auth
|
||||||
|
const effectiveStatus =
|
||||||
|
useAuth && sessionStatus === "loading" ? "loading" : state.status;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
data: state.data,
|
data: state.data,
|
||||||
status: state.status,
|
status: effectiveStatus,
|
||||||
hasMore: state.hasMore,
|
hasMore: state.hasMore,
|
||||||
error: state.error,
|
error: state.error,
|
||||||
loadMore,
|
loadMore,
|
11
llama_stack/ui/instrumentation.ts
Normal file
11
llama_stack/ui/instrumentation.ts
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
/**
|
||||||
|
* Next.js Instrumentation
|
||||||
|
* This file is used for initializing monitoring, tracing, or other observability tools.
|
||||||
|
* It runs once when the server starts, before any application code.
|
||||||
|
*
|
||||||
|
* Learn more: https://nextjs.org/docs/app/building-your-application/optimizing/instrumentation
|
||||||
|
*/
|
||||||
|
|
||||||
|
export async function register() {
|
||||||
|
await import("./lib/config-validator");
|
||||||
|
}
|
38
llama_stack/ui/lib/auth.ts
Normal file
38
llama_stack/ui/lib/auth.ts
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import { NextAuthOptions } from "next-auth";
|
||||||
|
import GithubProvider from "next-auth/providers/github";
|
||||||
|
|
||||||
|
export const authOptions: NextAuthOptions = {
|
||||||
|
providers: [
|
||||||
|
GithubProvider({
|
||||||
|
clientId: process.env.GITHUB_CLIENT_ID!,
|
||||||
|
clientSecret: process.env.GITHUB_CLIENT_SECRET!,
|
||||||
|
authorization: {
|
||||||
|
params: {
|
||||||
|
scope: "read:user user:email",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
debug: process.env.NODE_ENV === "development",
|
||||||
|
callbacks: {
|
||||||
|
async jwt({ token, account }) {
|
||||||
|
// Persist the OAuth access_token to the token right after signin
|
||||||
|
if (account) {
|
||||||
|
token.accessToken = account.access_token;
|
||||||
|
}
|
||||||
|
return token;
|
||||||
|
},
|
||||||
|
async session({ session, token }) {
|
||||||
|
// Send properties to the client, like an access_token from a provider.
|
||||||
|
session.accessToken = token.accessToken as string;
|
||||||
|
return session;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
pages: {
|
||||||
|
signIn: "/auth/signin",
|
||||||
|
error: "/auth/signin", // Redirect errors to our custom page
|
||||||
|
},
|
||||||
|
session: {
|
||||||
|
strategy: "jwt",
|
||||||
|
},
|
||||||
|
};
|
|
@ -1,6 +0,0 @@
|
||||||
import LlamaStackClient from "llama-stack-client";
|
|
||||||
|
|
||||||
export const client = new LlamaStackClient({
|
|
||||||
baseURL:
|
|
||||||
typeof window !== "undefined" ? `${window.location.origin}/api` : "/api",
|
|
||||||
});
|
|
56
llama_stack/ui/lib/config-validator.ts
Normal file
56
llama_stack/ui/lib/config-validator.ts
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Validates environment configuration for the application
|
||||||
|
* This is called during server initialization
|
||||||
|
*/
|
||||||
|
export function validateServerConfig() {
|
||||||
|
if (process.env.NODE_ENV === "development") {
|
||||||
|
console.log("🚀 Starting Llama Stack UI Server...");
|
||||||
|
|
||||||
|
// Check optional configurations
|
||||||
|
const optionalConfigs = {
|
||||||
|
NEXTAUTH_URL: process.env.NEXTAUTH_URL || "http://localhost:8322",
|
||||||
|
LLAMA_STACK_BACKEND_URL:
|
||||||
|
process.env.LLAMA_STACK_BACKEND_URL || "http://localhost:8321",
|
||||||
|
LLAMA_STACK_UI_PORT: process.env.LLAMA_STACK_UI_PORT || "8322",
|
||||||
|
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||||
|
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||||
|
};
|
||||||
|
|
||||||
|
console.log("\n📋 Configuration:");
|
||||||
|
console.log(` - NextAuth URL: ${optionalConfigs.NEXTAUTH_URL}`);
|
||||||
|
console.log(` - Backend URL: ${optionalConfigs.LLAMA_STACK_BACKEND_URL}`);
|
||||||
|
console.log(` - UI Port: ${optionalConfigs.LLAMA_STACK_UI_PORT}`);
|
||||||
|
|
||||||
|
// Check GitHub OAuth configuration
|
||||||
|
if (
|
||||||
|
!optionalConfigs.GITHUB_CLIENT_ID ||
|
||||||
|
!optionalConfigs.GITHUB_CLIENT_SECRET
|
||||||
|
) {
|
||||||
|
console.log(
|
||||||
|
"\n📝 GitHub OAuth not configured (authentication features disabled)",
|
||||||
|
);
|
||||||
|
console.log(" To enable GitHub OAuth:");
|
||||||
|
console.log(" 1. Go to https://github.com/settings/applications/new");
|
||||||
|
console.log(
|
||||||
|
" 2. Set Application name: Llama Stack UI (or your preferred name)",
|
||||||
|
);
|
||||||
|
console.log(" 3. Set Homepage URL: http://localhost:8322");
|
||||||
|
console.log(
|
||||||
|
" 4. Set Authorization callback URL: http://localhost:8322/api/auth/callback/github",
|
||||||
|
);
|
||||||
|
console.log(
|
||||||
|
" 5. Create the app and copy the Client ID and Client Secret",
|
||||||
|
);
|
||||||
|
console.log(" 6. Add them to your .env.local file:");
|
||||||
|
console.log(" GITHUB_CLIENT_ID=your_client_id");
|
||||||
|
console.log(" GITHUB_CLIENT_SECRET=your_client_secret");
|
||||||
|
} else {
|
||||||
|
console.log(" - GitHub OAuth: ✅ Configured");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call this function when the module is imported
|
||||||
|
validateServerConfig();
|
147
llama_stack/ui/package-lock.json
generated
147
llama_stack/ui/package-lock.json
generated
|
@ -18,6 +18,7 @@
|
||||||
"llama-stack-client": "0.2.13",
|
"llama-stack-client": "0.2.13",
|
||||||
"lucide-react": "^0.510.0",
|
"lucide-react": "^0.510.0",
|
||||||
"next": "15.3.3",
|
"next": "15.3.3",
|
||||||
|
"next-auth": "^4.24.11",
|
||||||
"next-themes": "^0.4.6",
|
"next-themes": "^0.4.6",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
|
@ -548,7 +549,6 @@
|
||||||
"version": "7.27.1",
|
"version": "7.27.1",
|
||||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz",
|
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz",
|
||||||
"integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==",
|
"integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==",
|
||||||
"dev": true,
|
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=6.9.0"
|
"node": ">=6.9.0"
|
||||||
|
@ -2423,6 +2423,15 @@
|
||||||
"node": ">=12.4.0"
|
"node": ">=12.4.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@panva/hkdf": {
|
||||||
|
"version": "1.2.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/@panva/hkdf/-/hkdf-1.2.1.tgz",
|
||||||
|
"integrity": "sha512-6oclG6Y3PiDFcoyk8srjLfVKyMfVCKJ27JwNPViuXziFpmdz+MZnZN/aKY0JGXgYuO/VghU0jcOAZgWXZ1Dmrw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"funding": {
|
||||||
|
"url": "https://github.com/sponsors/panva"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@pkgr/core": {
|
"node_modules/@pkgr/core": {
|
||||||
"version": "0.2.4",
|
"version": "0.2.4",
|
||||||
"resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.4.tgz",
|
"resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.4.tgz",
|
||||||
|
@ -5279,7 +5288,6 @@
|
||||||
"version": "0.7.2",
|
"version": "0.7.2",
|
||||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz",
|
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz",
|
||||||
"integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==",
|
"integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==",
|
||||||
"dev": true,
|
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">= 0.6"
|
"node": ">= 0.6"
|
||||||
|
@ -9036,6 +9044,15 @@
|
||||||
"jiti": "lib/jiti-cli.mjs"
|
"jiti": "lib/jiti-cli.mjs"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/jose": {
|
||||||
|
"version": "4.15.9",
|
||||||
|
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz",
|
||||||
|
"integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"funding": {
|
||||||
|
"url": "https://github.com/sponsors/panva"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/js-tokens": {
|
"node_modules/js-tokens": {
|
||||||
"version": "4.0.0",
|
"version": "4.0.0",
|
||||||
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
|
||||||
|
@ -9949,6 +9966,38 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/next-auth": {
|
||||||
|
"version": "4.24.11",
|
||||||
|
"resolved": "https://registry.npmjs.org/next-auth/-/next-auth-4.24.11.tgz",
|
||||||
|
"integrity": "sha512-pCFXzIDQX7xmHFs4KVH4luCjaCbuPRtZ9oBUjUhOk84mZ9WVPf94n87TxYI4rSRf9HmfHEF8Yep3JrYDVOo3Cw==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"@babel/runtime": "^7.20.13",
|
||||||
|
"@panva/hkdf": "^1.0.2",
|
||||||
|
"cookie": "^0.7.0",
|
||||||
|
"jose": "^4.15.5",
|
||||||
|
"oauth": "^0.9.15",
|
||||||
|
"openid-client": "^5.4.0",
|
||||||
|
"preact": "^10.6.3",
|
||||||
|
"preact-render-to-string": "^5.1.19",
|
||||||
|
"uuid": "^8.3.2"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@auth/core": "0.34.2",
|
||||||
|
"next": "^12.2.5 || ^13 || ^14 || ^15",
|
||||||
|
"nodemailer": "^6.6.5",
|
||||||
|
"react": "^17.0.2 || ^18 || ^19",
|
||||||
|
"react-dom": "^17.0.2 || ^18 || ^19"
|
||||||
|
},
|
||||||
|
"peerDependenciesMeta": {
|
||||||
|
"@auth/core": {
|
||||||
|
"optional": true
|
||||||
|
},
|
||||||
|
"nodemailer": {
|
||||||
|
"optional": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/next-themes": {
|
"node_modules/next-themes": {
|
||||||
"version": "0.4.6",
|
"version": "0.4.6",
|
||||||
"resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.4.6.tgz",
|
"resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.4.6.tgz",
|
||||||
|
@ -10071,6 +10120,12 @@
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/oauth": {
|
||||||
|
"version": "0.9.15",
|
||||||
|
"resolved": "https://registry.npmjs.org/oauth/-/oauth-0.9.15.tgz",
|
||||||
|
"integrity": "sha512-a5ERWK1kh38ExDEfoO6qUHJb32rd7aYmPHuyCu3Fta/cnICvYmgd2uhuKXvPD+PXB+gCEYYEaQdIRAjCOwAKNA==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/object-assign": {
|
"node_modules/object-assign": {
|
||||||
"version": "4.1.1",
|
"version": "4.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
|
||||||
|
@ -10081,6 +10136,15 @@
|
||||||
"node": ">=0.10.0"
|
"node": ">=0.10.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/object-hash": {
|
||||||
|
"version": "2.2.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz",
|
||||||
|
"integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 6"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/object-inspect": {
|
"node_modules/object-inspect": {
|
||||||
"version": "1.13.4",
|
"version": "1.13.4",
|
||||||
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
|
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
|
||||||
|
@ -10194,6 +10258,15 @@
|
||||||
"url": "https://github.com/sponsors/ljharb"
|
"url": "https://github.com/sponsors/ljharb"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/oidc-token-hash": {
|
||||||
|
"version": "5.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.1.0.tgz",
|
||||||
|
"integrity": "sha512-y0W+X7Ppo7oZX6eovsRkuzcSM40Bicg2JEJkDJ4irIt1wsYAP5MLSNv+QAogO8xivMffw/9OvV3um1pxXgt1uA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"engines": {
|
||||||
|
"node": "^10.13.0 || >=12.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/on-finished": {
|
"node_modules/on-finished": {
|
||||||
"version": "2.4.1",
|
"version": "2.4.1",
|
||||||
"resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz",
|
"resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz",
|
||||||
|
@ -10233,6 +10306,39 @@
|
||||||
"url": "https://github.com/sponsors/sindresorhus"
|
"url": "https://github.com/sponsors/sindresorhus"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/openid-client": {
|
||||||
|
"version": "5.7.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.7.1.tgz",
|
||||||
|
"integrity": "sha512-jDBPgSVfTnkIh71Hg9pRvtJc6wTwqjRkN88+gCFtYWrlP4Yx2Dsrow8uPi3qLr/aeymPF3o2+dS+wOpglK04ew==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"jose": "^4.15.9",
|
||||||
|
"lru-cache": "^6.0.0",
|
||||||
|
"object-hash": "^2.2.0",
|
||||||
|
"oidc-token-hash": "^5.0.3"
|
||||||
|
},
|
||||||
|
"funding": {
|
||||||
|
"url": "https://github.com/sponsors/panva"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/openid-client/node_modules/lru-cache": {
|
||||||
|
"version": "6.0.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz",
|
||||||
|
"integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"yallist": "^4.0.0"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/openid-client/node_modules/yallist": {
|
||||||
|
"version": "4.0.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
|
||||||
|
"integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==",
|
||||||
|
"license": "ISC"
|
||||||
|
},
|
||||||
"node_modules/optionator": {
|
"node_modules/optionator": {
|
||||||
"version": "0.9.4",
|
"version": "0.9.4",
|
||||||
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
|
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
|
||||||
|
@ -10560,6 +10666,34 @@
|
||||||
"node": "^10 || ^12 || >=14"
|
"node": "^10 || ^12 || >=14"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/preact": {
|
||||||
|
"version": "10.26.9",
|
||||||
|
"resolved": "https://registry.npmjs.org/preact/-/preact-10.26.9.tgz",
|
||||||
|
"integrity": "sha512-SSjF9vcnF27mJK1XyFMNJzFd5u3pQiATFqoaDy03XuN00u4ziveVVEGt5RKJrDR8MHE/wJo9Nnad56RLzS2RMA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/preact"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/preact-render-to-string": {
|
||||||
|
"version": "5.2.6",
|
||||||
|
"resolved": "https://registry.npmjs.org/preact-render-to-string/-/preact-render-to-string-5.2.6.tgz",
|
||||||
|
"integrity": "sha512-JyhErpYOvBV1hEPwIxc/fHWXPfnEGdRKxc8gFdAZ7XV4tlzyzG847XAyEZqoDnynP88akM4eaHcSOzNcLWFguw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"pretty-format": "^3.8.0"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"preact": ">=10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/preact-render-to-string/node_modules/pretty-format": {
|
||||||
|
"version": "3.8.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-3.8.0.tgz",
|
||||||
|
"integrity": "sha512-WuxUnVtlWL1OfZFQFuqvnvs6MiAGk9UNsBostyBOB0Is9wb5uRESevA6rnl/rkksXaGX3GzZhPup5d6Vp1nFew==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/prelude-ls": {
|
"node_modules/prelude-ls": {
|
||||||
"version": "1.2.1",
|
"version": "1.2.1",
|
||||||
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
|
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
|
||||||
|
@ -12409,6 +12543,15 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/uuid": {
|
||||||
|
"version": "8.3.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz",
|
||||||
|
"integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==",
|
||||||
|
"license": "MIT",
|
||||||
|
"bin": {
|
||||||
|
"uuid": "dist/bin/uuid"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/v8-compile-cache-lib": {
|
"node_modules/v8-compile-cache-lib": {
|
||||||
"version": "3.0.1",
|
"version": "3.0.1",
|
||||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
"llama-stack-client": "0.2.13",
|
"llama-stack-client": "0.2.13",
|
||||||
"lucide-react": "^0.510.0",
|
"lucide-react": "^0.510.0",
|
||||||
"next": "15.3.3",
|
"next": "15.3.3",
|
||||||
|
"next-auth": "^4.24.11",
|
||||||
"next-themes": "^0.4.6",
|
"next-themes": "^0.4.6",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
|
|
7
llama_stack/ui/types/next-auth.d.ts
vendored
Normal file
7
llama_stack/ui/types/next-auth.d.ts
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
import NextAuth from "next-auth";
|
||||||
|
|
||||||
|
declare module "next-auth" {
|
||||||
|
interface Session {
|
||||||
|
accessToken?: string;
|
||||||
|
}
|
||||||
|
}
|
|
@ -86,6 +86,7 @@ unit = [
|
||||||
"sqlalchemy[asyncio]>=2.0.41",
|
"sqlalchemy[asyncio]>=2.0.41",
|
||||||
"blobfile",
|
"blobfile",
|
||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
|
"pymilvus>=2.5.12",
|
||||||
]
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
|
@ -106,6 +107,7 @@ test = [
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"sqlalchemy[asyncio]>=2.0.41",
|
"sqlalchemy[asyncio]>=2.0.41",
|
||||||
"requests",
|
"requests",
|
||||||
|
"pymilvus>=2.5.12",
|
||||||
]
|
]
|
||||||
docs = [
|
docs = [
|
||||||
"setuptools",
|
"setuptools",
|
||||||
|
|
5
tests/integration/providers/utils/__init__.py
Normal file
5
tests/integration/providers/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -0,0 +1,173 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.access_control.access_control import default_policy
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_postgres_config():
|
||||||
|
"""Get PostgreSQL configuration if tests are enabled."""
|
||||||
|
return PostgresSqlStoreConfig(
|
||||||
|
host=os.environ.get("POSTGRES_HOST", "localhost"),
|
||||||
|
port=int(os.environ.get("POSTGRES_PORT", "5432")),
|
||||||
|
db=os.environ.get("POSTGRES_DB", "llamastack"),
|
||||||
|
user=os.environ.get("POSTGRES_USER", "llamastack"),
|
||||||
|
password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sqlite_config():
|
||||||
|
"""Get SQLite configuration with temporary database."""
|
||||||
|
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||||
|
tmp_file.close()
|
||||||
|
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"backend_config",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
("postgres", get_postgres_config),
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not os.environ.get("ENABLE_POSTGRES_TESTS"),
|
||||||
|
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
|
||||||
|
),
|
||||||
|
id="postgres",
|
||||||
|
),
|
||||||
|
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
|
async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
||||||
|
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
|
||||||
|
backend_name, config_func = backend_config
|
||||||
|
|
||||||
|
# Handle different config types
|
||||||
|
if backend_name == "postgres":
|
||||||
|
config = config_func()
|
||||||
|
cleanup_path = None
|
||||||
|
else: # sqlite
|
||||||
|
config, cleanup_path = config_func()
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_sqlstore = SqlAlchemySqlStoreImpl(config)
|
||||||
|
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||||
|
|
||||||
|
# Create test table
|
||||||
|
table_name = f"test_json_comparison_{backend_name}"
|
||||||
|
await authorized_store.create_table(
|
||||||
|
table=table_name,
|
||||||
|
schema={
|
||||||
|
"id": ColumnType.STRING,
|
||||||
|
"data": ColumnType.STRING,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test with no authenticated user (should handle JSON null comparison)
|
||||||
|
mock_get_authenticated_user.return_value = None
|
||||||
|
|
||||||
|
# Insert some test data
|
||||||
|
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||||
|
|
||||||
|
# Test fetching with no user - should not error on JSON comparison
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["id"] == "1"
|
||||||
|
assert result.data[0]["access_attributes"] is None
|
||||||
|
|
||||||
|
# Test with authenticated user
|
||||||
|
test_user = User("test-user", {"roles": ["admin"]})
|
||||||
|
mock_get_authenticated_user.return_value = test_user
|
||||||
|
|
||||||
|
# Insert data with user attributes
|
||||||
|
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||||
|
|
||||||
|
# Fetch all - admin should see both
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 2
|
||||||
|
|
||||||
|
# Test with non-admin user
|
||||||
|
regular_user = User("regular-user", {"roles": ["user"]})
|
||||||
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
|
# Should only see public record
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["id"] == "1"
|
||||||
|
|
||||||
|
# Test the category missing branch: user with multiple attributes
|
||||||
|
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
|
||||||
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
|
|
||||||
|
# Insert record with multi-user (has both roles and teams)
|
||||||
|
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
|
||||||
|
|
||||||
|
# Test different user types to create records with different attribute patterns
|
||||||
|
# Record with only roles (teams category will be missing)
|
||||||
|
roles_only_user = User("roles-user", {"roles": ["admin"]})
|
||||||
|
mock_get_authenticated_user.return_value = roles_only_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
|
||||||
|
|
||||||
|
# Record with only teams (roles category will be missing)
|
||||||
|
teams_only_user = User("teams-user", {"teams": ["dev"]})
|
||||||
|
mock_get_authenticated_user.return_value = teams_only_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
|
||||||
|
|
||||||
|
# Record with different roles/teams (shouldn't match our test user)
|
||||||
|
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
|
||||||
|
mock_get_authenticated_user.return_value = different_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
|
||||||
|
|
||||||
|
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||||
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
|
||||||
|
# Should see:
|
||||||
|
# - public record (1) - no access_attributes
|
||||||
|
# - admin record (2) - user matches roles=admin, teams missing (allowed)
|
||||||
|
# - multi_user record (3) - user matches both roles=admin and teams=dev
|
||||||
|
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
|
||||||
|
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
|
||||||
|
# Should NOT see:
|
||||||
|
# - different_user record (6) - user doesn't match roles=user or teams=qa
|
||||||
|
expected_ids = {"1", "2", "3", "4", "5"}
|
||||||
|
actual_ids = {record["id"] for record in result.data}
|
||||||
|
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
|
||||||
|
|
||||||
|
# Verify the category missing logic specifically
|
||||||
|
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
|
||||||
|
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
|
||||||
|
assert category_test_ids == {"4", "5"}, (
|
||||||
|
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up records
|
||||||
|
for record_id in ["1", "2", "3", "4", "5", "6"]:
|
||||||
|
try:
|
||||||
|
await base_sqlstore.delete(table_name, {"id": record_id})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up temporary SQLite database file if needed
|
||||||
|
if cleanup_path:
|
||||||
|
try:
|
||||||
|
os.unlink(cleanup_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
|
@ -0,0 +1,420 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from pymilvus import Collection, MilvusClient, connections
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||||
|
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
# TODO: Refactor these to be for inline vector-io providers
|
||||||
|
MILVUS_ALIAS = "test_milvus"
|
||||||
|
COLLECTION_PREFIX = "test_collection"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def loop():
|
||||||
|
return asyncio.new_event_loop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mock_inference_api(embedding_dimension):
|
||||||
|
class MockInferenceAPI:
|
||||||
|
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
|
||||||
|
|
||||||
|
return MockInferenceAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def unique_kvstore_config(tmp_path_factory):
|
||||||
|
# Generate a unique filename for this test
|
||||||
|
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||||
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
|
db_path = str(temp_dir / f"{unique_id}.db")
|
||||||
|
|
||||||
|
return SqliteKVStoreConfig(db_path=db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||||
|
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
|
||||||
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
|
db_path = str(temp_dir / "test_milvus.db")
|
||||||
|
client = MilvusClient(db_path)
|
||||||
|
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||||
|
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
|
||||||
|
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||||
|
index.db_path = db_path
|
||||||
|
yield index
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
|
||||||
|
config = MilvusVectorIOConfig(
|
||||||
|
db_path=milvus_vec_index.db_path,
|
||||||
|
kvstore=SqliteKVStoreConfig(),
|
||||||
|
)
|
||||||
|
adapter = MilvusVectorIOAdapter(
|
||||||
|
config=config,
|
||||||
|
inference_api=mock_inference_api,
|
||||||
|
files_api=None,
|
||||||
|
)
|
||||||
|
await adapter.initialize()
|
||||||
|
await adapter.register_vector_db(
|
||||||
|
VectorDB(
|
||||||
|
identifier=adapter.metadata_collection_name,
|
||||||
|
provider_id="test_provider",
|
||||||
|
embedding_model="test_model",
|
||||||
|
embedding_dimension=128,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield adapter
|
||||||
|
await adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_contains_initial_collection(milvus_vec_adapter):
|
||||||
|
coll_name = milvus_vec_adapter.metadata_collection_name
|
||||||
|
assert coll_name in milvus_vec_adapter.cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
||||||
|
assert resp.chunks[0].content == sample_chunks[0].content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
|
||||||
|
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
query_emb = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
|
||||||
|
assert isinstance(resp, QueryChunksResponse)
|
||||||
|
assert len(resp.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
|
||||||
|
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
||||||
|
await milvus_vec_index.add_chunks(sample_chunks, embeddings)
|
||||||
|
coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
|
||||||
|
ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
|
||||||
|
flat_ids = [i["id"] for i in ids]
|
||||||
|
assert len(flat_ids) == len(set(flat_ids))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
|
||||||
|
kvstore = await kvstore_impl(unique_kvstore_config)
|
||||||
|
vector_db = VectorDB(
|
||||||
|
identifier="test_db",
|
||||||
|
provider_id="test_provider",
|
||||||
|
embedding_model="test_model",
|
||||||
|
embedding_dimension=128,
|
||||||
|
metadata={"test_key": "test_value"},
|
||||||
|
)
|
||||||
|
test_vector_db_data = vector_db.model_dump_json()
|
||||||
|
await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
|
||||||
|
tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
|
||||||
|
config=MilvusVectorIOConfig(
|
||||||
|
db_path=milvus_vec_index.db_path,
|
||||||
|
kvstore=unique_kvstore_config,
|
||||||
|
),
|
||||||
|
inference_api=None,
|
||||||
|
files_api=None,
|
||||||
|
)
|
||||||
|
await tmp_milvus_vec_adapter.initialize()
|
||||||
|
|
||||||
|
vector_db = VectorDB(
|
||||||
|
identifier="test_db",
|
||||||
|
provider_id="test_provider",
|
||||||
|
embedding_model="test_model",
|
||||||
|
embedding_dimension=128,
|
||||||
|
)
|
||||||
|
test_vector_db_data = vector_db.model_dump_json()
|
||||||
|
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
|
||||||
|
|
||||||
|
assert milvus_vec_index.client is not None
|
||||||
|
assert isinstance(milvus_vec_index.client, MilvusClient)
|
||||||
|
assert tmp_milvus_vec_adapter.cache is not None
|
||||||
|
# registering a vector won't update the cache or openai_vector_store collection name
|
||||||
|
assert (
|
||||||
|
tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
|
||||||
|
or tmp_milvus_vec_adapter.openai_vector_stores
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_persistence_across_adapter_restarts(
|
||||||
|
tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
|
||||||
|
):
|
||||||
|
adapter1 = MilvusVectorIOAdapter(
|
||||||
|
config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
|
||||||
|
inference_api=mock_inference_api,
|
||||||
|
files_api=None,
|
||||||
|
)
|
||||||
|
await adapter1.initialize()
|
||||||
|
dummy = VectorDB(
|
||||||
|
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||||
|
)
|
||||||
|
await adapter1.register_vector_db(dummy)
|
||||||
|
await adapter1.shutdown()
|
||||||
|
|
||||||
|
await adapter1.initialize()
|
||||||
|
assert "foo_db" in adapter1.cache
|
||||||
|
await adapter1.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_and_unregister_vector_db(milvus_vec_adapter):
|
||||||
|
try:
|
||||||
|
connections.disconnect(MILVUS_ALIAS)
|
||||||
|
except Exception as _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
|
||||||
|
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||||
|
dummy = VectorDB(
|
||||||
|
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||||
|
)
|
||||||
|
|
||||||
|
await milvus_vec_adapter.register_vector_db(dummy)
|
||||||
|
assert dummy.identifier in milvus_vec_adapter.cache
|
||||||
|
|
||||||
|
if dummy.identifier in milvus_vec_adapter.cache:
|
||||||
|
index = milvus_vec_adapter.cache[dummy.identifier].index
|
||||||
|
if hasattr(index, "client") and hasattr(index.client, "_using"):
|
||||||
|
index.client._using = MILVUS_ALIAS
|
||||||
|
|
||||||
|
await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
|
||||||
|
assert dummy.identifier not in milvus_vec_adapter.cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_unregistered_raises(milvus_vec_adapter):
|
||||||
|
fake_emb = np.zeros(8, dtype=np.float32)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
|
||||||
|
fake_index = AsyncMock()
|
||||||
|
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||||
|
|
||||||
|
chunks = ["chunk1", "chunk2"]
|
||||||
|
await milvus_vec_adapter.insert_chunks("db1", chunks)
|
||||||
|
|
||||||
|
fake_index.insert_chunks.assert_awaited_once_with(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
|
||||||
|
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await milvus_vec_adapter.insert_chunks("db_not_exist", [])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
|
||||||
|
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||||
|
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||||
|
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||||
|
|
||||||
|
response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})
|
||||||
|
|
||||||
|
fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
|
||||||
|
assert response is expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
|
||||||
|
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await milvus_vec_adapter.query_chunks("db_missing", "q", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_openai_vector_store(milvus_vec_adapter):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
openai_vector_store = {
|
||||||
|
"id": store_id,
|
||||||
|
"name": "Test Store",
|
||||||
|
"description": "A test OpenAI vector store",
|
||||||
|
"vector_db_id": "test_db",
|
||||||
|
"embedding_model": "test_model",
|
||||||
|
}
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||||
|
|
||||||
|
assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
|
||||||
|
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_openai_vector_store(milvus_vec_adapter):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
openai_vector_store = {
|
||||||
|
"id": store_id,
|
||||||
|
"name": "Test Store",
|
||||||
|
"description": "A test OpenAI vector store",
|
||||||
|
"vector_db_id": "test_db",
|
||||||
|
"embedding_model": "test_model",
|
||||||
|
}
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||||
|
openai_vector_store["description"] = "Updated description"
|
||||||
|
await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
|
||||||
|
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_openai_vector_store(milvus_vec_adapter):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
openai_vector_store = {
|
||||||
|
"id": store_id,
|
||||||
|
"name": "Test Store",
|
||||||
|
"description": "A test OpenAI vector store",
|
||||||
|
"vector_db_id": "test_db",
|
||||||
|
"embedding_model": "test_model",
|
||||||
|
}
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||||
|
await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
|
||||||
|
assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_openai_vector_stores(milvus_vec_adapter):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
openai_vector_store = {
|
||||||
|
"id": store_id,
|
||||||
|
"name": "Test Store",
|
||||||
|
"description": "A test OpenAI vector store",
|
||||||
|
"vector_db_id": "test_db",
|
||||||
|
"embedding_model": "test_model",
|
||||||
|
}
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||||
|
loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
|
||||||
|
assert loaded_stores[store_id] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
file_id = "file_1234"
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"id": file_id,
|
||||||
|
"status": "completed",
|
||||||
|
"vector_store_id": store_id,
|
||||||
|
"attributes": {},
|
||||||
|
"filename": "test_file.txt",
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
|
||||||
|
file_contents = [
|
||||||
|
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||||
|
]
|
||||||
|
|
||||||
|
# validating we don't raise an exception
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
file_id = "file_1234"
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"id": file_id,
|
||||||
|
"status": "completed",
|
||||||
|
"vector_store_id": store_id,
|
||||||
|
"attributes": {},
|
||||||
|
"filename": "test_file.txt",
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
|
||||||
|
file_contents = [
|
||||||
|
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||||
|
]
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||||
|
|
||||||
|
updated_file_info = file_info.copy()
|
||||||
|
updated_file_info["filename"] = "updated_test_file.txt"
|
||||||
|
|
||||||
|
await milvus_vec_adapter._update_openai_vector_store_file(
|
||||||
|
store_id,
|
||||||
|
file_id,
|
||||||
|
updated_file_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
|
||||||
|
assert loaded_contents == updated_file_info
|
||||||
|
assert loaded_contents != file_info
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
file_id = "file_1234"
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"id": file_id,
|
||||||
|
"status": "completed",
|
||||||
|
"vector_store_id": store_id,
|
||||||
|
"attributes": {},
|
||||||
|
"filename": "test_file.txt",
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
|
||||||
|
file_contents = [
|
||||||
|
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||||
|
]
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||||
|
|
||||||
|
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||||
|
assert loaded_contents == file_contents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
|
||||||
|
store_id = "vs_1234"
|
||||||
|
file_id = "file_1234"
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"id": file_id,
|
||||||
|
"status": "completed",
|
||||||
|
"vector_store_id": store_id,
|
||||||
|
"attributes": {},
|
||||||
|
"filename": "test_file.txt",
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
|
||||||
|
file_contents = [
|
||||||
|
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||||
|
]
|
||||||
|
|
||||||
|
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||||
|
await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
|
||||||
|
|
||||||
|
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||||
|
assert loaded_contents == []
|
|
@ -11,10 +11,16 @@ import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
from llama_stack.distribution.datatypes import (
|
||||||
|
AuthenticationConfig,
|
||||||
|
AuthProviderType,
|
||||||
|
CustomAuthConfig,
|
||||||
|
OAuth2IntrospectionConfig,
|
||||||
|
OAuth2JWKSConfig,
|
||||||
|
OAuth2TokenAuthConfig,
|
||||||
|
)
|
||||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
from llama_stack.distribution.server.auth_providers import (
|
from llama_stack.distribution.server.auth_providers import (
|
||||||
AuthProviderType,
|
|
||||||
get_attributes_from_claims,
|
get_attributes_from_claims,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,24 +67,11 @@ def invalid_token():
|
||||||
def http_app(mock_auth_endpoint):
|
def http_app(mock_auth_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.CUSTOM,
|
provider_config=CustomAuthConfig(
|
||||||
config={"endpoint": mock_auth_endpoint},
|
type=AuthProviderType.CUSTOM,
|
||||||
)
|
endpoint=mock_auth_endpoint,
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
),
|
||||||
|
access_policy=[],
|
||||||
@app.get("/test")
|
|
||||||
def test_endpoint():
|
|
||||||
return {"message": "Authentication successful"}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def k8s_app():
|
|
||||||
app = FastAPI()
|
|
||||||
auth_config = AuthenticationConfig(
|
|
||||||
provider_type=AuthProviderType.KUBERNETES,
|
|
||||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -94,11 +87,6 @@ def http_client(http_app):
|
||||||
return TestClient(http_app)
|
return TestClient(http_app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def k8s_client(k8s_app):
|
|
||||||
return TestClient(k8s_app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_scope():
|
def mock_scope():
|
||||||
return {
|
return {
|
||||||
|
@ -117,18 +105,11 @@ def mock_scope():
|
||||||
def mock_http_middleware(mock_auth_endpoint):
|
def mock_http_middleware(mock_auth_endpoint):
|
||||||
mock_app = AsyncMock()
|
mock_app = AsyncMock()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.CUSTOM,
|
provider_config=CustomAuthConfig(
|
||||||
config={"endpoint": mock_auth_endpoint},
|
type=AuthProviderType.CUSTOM,
|
||||||
)
|
endpoint=mock_auth_endpoint,
|
||||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
),
|
||||||
|
access_policy=[],
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_k8s_middleware():
|
|
||||||
mock_app = AsyncMock()
|
|
||||||
auth_config = AuthenticationConfig(
|
|
||||||
provider_type=AuthProviderType.KUBERNETES,
|
|
||||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
|
||||||
)
|
)
|
||||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||||
|
|
||||||
|
@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
|
||||||
def test_missing_auth_header(http_client):
|
def test_missing_auth_header(http_client):
|
||||||
response = http_client.get("/test")
|
response = http_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format(http_client):
|
def test_invalid_auth_header_format(http_client):
|
||||||
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||||
|
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
||||||
def oauth2_app():
|
def oauth2_app():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
provider_config=OAuth2TokenAuthConfig(
|
||||||
config={
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
"jwks": {
|
jwks=OAuth2JWKSConfig(
|
||||||
"uri": "http://mock-authz-service/token/introspect",
|
uri="http://mock-authz-service/token/introspect",
|
||||||
"key_recheck_period": "3600",
|
),
|
||||||
},
|
audience="llama-stack",
|
||||||
"audience": "llama-stack",
|
),
|
||||||
},
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
|
||||||
def test_missing_auth_header_oauth2(oauth2_client):
|
def test_missing_auth_header_oauth2(oauth2_client):
|
||||||
response = oauth2_client.get("/test")
|
response = oauth2_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
async def mock_jwks_response(*args, **kwargs):
|
async def mock_jwks_response(*args, **kwargs):
|
||||||
|
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
||||||
def oauth2_app_with_jwks_token():
|
def oauth2_app_with_jwks_token():
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
provider_config=OAuth2TokenAuthConfig(
|
||||||
config={
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
"jwks": {
|
jwks=OAuth2JWKSConfig(
|
||||||
"uri": "http://mock-authz-service/token/introspect",
|
uri="http://mock-authz-service/token/introspect",
|
||||||
"key_recheck_period": "3600",
|
key_recheck_period=3600,
|
||||||
"token": "my-jwks-token",
|
token="my-jwks-token",
|
||||||
},
|
),
|
||||||
"audience": "llama-stack",
|
audience="llama-stack",
|
||||||
},
|
),
|
||||||
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
|
||||||
def introspection_app(mock_introspection_endpoint):
|
def introspection_app(mock_introspection_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
provider_config=OAuth2TokenAuthConfig(
|
||||||
config={
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
"jwks": None,
|
introspection=OAuth2IntrospectionConfig(
|
||||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
url=mock_introspection_endpoint,
|
||||||
},
|
client_id="myclient",
|
||||||
|
client_secret="abcdefg",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
|
||||||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
auth_config = AuthenticationConfig(
|
auth_config = AuthenticationConfig(
|
||||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
provider_config=OAuth2TokenAuthConfig(
|
||||||
config={
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
"jwks": None,
|
introspection=OAuth2IntrospectionConfig(
|
||||||
"introspection": {
|
url=mock_introspection_endpoint,
|
||||||
"url": mock_introspection_endpoint,
|
client_id="myclient",
|
||||||
"client_id": "myclient",
|
client_secret="abcdefg",
|
||||||
"client_secret": "abcdefg",
|
send_secret_in_body=True,
|
||||||
"send_secret_in_body": "true",
|
),
|
||||||
},
|
claims_mapping={
|
||||||
"claims_mapping": {
|
|
||||||
"sub": "roles",
|
"sub": "roles",
|
||||||
"scope": "roles",
|
"scope": "roles",
|
||||||
"groups": "teams",
|
"groups": "teams",
|
||||||
"aud": "namespaces",
|
"aud": "namespaces",
|
||||||
},
|
},
|
||||||
},
|
),
|
||||||
|
access_policy=[],
|
||||||
)
|
)
|
||||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
||||||
def test_missing_auth_header_introspection(introspection_client):
|
def test_missing_auth_header_introspection(introspection_client):
|
||||||
response = introspection_client.get("/test")
|
response = introspection_client.get("/test")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||||
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
async def mock_introspection_active(*args, **kwargs):
|
async def mock_introspection_active(*args, **kwargs):
|
||||||
|
|
200
tests/unit/server/test_auth_github.py
Normal file
200
tests/unit/server/test_auth_github.py
Normal file
|
@ -0,0 +1,200 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
|
||||||
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
if self.status_code != 200:
|
||||||
|
# Create a mock request for the HTTPStatusError
|
||||||
|
mock_request = httpx.Request("GET", "https://api.github.com/user")
|
||||||
|
raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_token_app():
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Configure GitHub token auth
|
||||||
|
auth_config = AuthenticationConfig(
|
||||||
|
provider_config=GitHubTokenAuthConfig(
|
||||||
|
type=AuthProviderType.GITHUB_TOKEN,
|
||||||
|
github_api_base_url="https://api.github.com",
|
||||||
|
claims_mapping={
|
||||||
|
"login": "username",
|
||||||
|
"id": "user_id",
|
||||||
|
"organizations": "teams",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
access_policy=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add auth middleware
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_token_client(github_token_app):
|
||||||
|
return TestClient(github_token_app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_authenticated_endpoint_without_token(github_token_client):
|
||||||
|
"""Test accessing protected endpoint without token"""
|
||||||
|
response = github_token_client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Authentication required" in response.json()["error"]["message"]
|
||||||
|
assert "GitHub access token" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
|
||||||
|
"""Test accessing protected endpoint with invalid bearer format"""
|
||||||
|
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||||
|
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
|
||||||
|
"""Test accessing protected endpoint with valid GitHub token"""
|
||||||
|
# Mock the GitHub API responses
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||||
|
|
||||||
|
# Mock successful user API response
|
||||||
|
mock_client.get.side_effect = [
|
||||||
|
MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"login": "testuser",
|
||||||
|
"id": 12345,
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResponse(
|
||||||
|
200,
|
||||||
|
[
|
||||||
|
{"login": "test-org-1"},
|
||||||
|
{"login": "test-org-2"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["message"] == "Authentication successful"
|
||||||
|
|
||||||
|
# Verify the GitHub API was called correctly
|
||||||
|
assert mock_client.get.call_count == 1
|
||||||
|
calls = mock_client.get.call_args_list
|
||||||
|
assert calls[0][0][0] == "https://api.github.com/user"
|
||||||
|
|
||||||
|
# Check authorization header was passed
|
||||||
|
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||||
|
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||||
|
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||||
|
# Mock the GitHub API to return 401 Unauthorized
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||||
|
|
||||||
|
# Mock failed user API response
|
||||||
|
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
|
||||||
|
|
||||||
|
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert (
|
||||||
|
"GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||||
|
def test_github_enterprise_support(mock_client_class):
|
||||||
|
"""Test GitHub Enterprise support with custom API base URL"""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Configure GitHub token auth with enterprise URL
|
||||||
|
auth_config = AuthenticationConfig(
|
||||||
|
provider_config=GitHubTokenAuthConfig(
|
||||||
|
type=AuthProviderType.GITHUB_TOKEN,
|
||||||
|
github_api_base_url="https://github.enterprise.com/api/v3",
|
||||||
|
),
|
||||||
|
access_policy=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "Authentication successful"}
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Mock the GitHub Enterprise API responses
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||||
|
|
||||||
|
# Mock successful user API response
|
||||||
|
mock_client.get.side_effect = [
|
||||||
|
MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"login": "enterprise_user",
|
||||||
|
"id": 99999,
|
||||||
|
"email": "user@enterprise.com",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResponse(
|
||||||
|
200,
|
||||||
|
[
|
||||||
|
{"login": "enterprise-org"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify the correct GitHub Enterprise URLs were called
|
||||||
|
assert mock_client.get.call_count == 1
|
||||||
|
calls = mock_client.get.call_args_list
|
||||||
|
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_token_auth_error_message_format(github_token_client):
|
||||||
|
"""Test that the error message for missing auth is properly formatted"""
|
||||||
|
response = github_token_client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
error_data = response.json()
|
||||||
|
assert "error" in error_data
|
||||||
|
assert "message" in error_data["error"]
|
||||||
|
assert "Authentication required" in error_data["error"]["message"]
|
||||||
|
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs
|
|
@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
|
|
||||||
# Test scenarios with different access control patterns
|
# Test scenarios with different access control patterns
|
||||||
test_scenarios = [
|
test_scenarios = [
|
||||||
# Scenario 1: Public record (no access control)
|
# Scenario 1: Public record (no access control - represents None user insert)
|
||||||
{"id": "1", "name": "public", "access_attributes": None},
|
{"id": "1", "name": "public", "access_attributes": None},
|
||||||
# Scenario 2: Empty access control (should be treated as public)
|
# Scenario 2: Record with roles requirement
|
||||||
{"id": "2", "name": "empty", "access_attributes": {}},
|
{"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||||
# Scenario 3: Record with roles requirement
|
# Scenario 3: Record with multiple attribute categories
|
||||||
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
{"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||||
# Scenario 4: Record with multiple attribute categories
|
# Scenario 4: Record with teams only (missing roles category)
|
||||||
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
{"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||||
# Scenario 5: Record with teams only (missing roles category)
|
# Scenario 5: Record with roles and projects
|
||||||
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
|
||||||
# Scenario 6: Record with roles and projects
|
|
||||||
{
|
{
|
||||||
"id": "6",
|
"id": "5",
|
||||||
"name": "admin-project-x",
|
"name": "admin-project-x",
|
||||||
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue