mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Merge branch 'main' into vectordb_name
This commit is contained in:
commit
b103ee9eb8
54 changed files with 4277 additions and 1773 deletions
9
.github/workflows/integration-auth-tests.yml
vendored
9
.github/workflows/integration-auth-tests.yml
vendored
|
@ -73,9 +73,12 @@ jobs:
|
|||
server:
|
||||
port: 8321
|
||||
EOF
|
||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
|
||||
cat $run_dir/run.yaml
|
||||
|
||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||
|
|
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
|
@ -0,0 +1,70 @@
|
|||
name: SqlStore Integration Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'llama_stack/providers/utils/sqlstore/**'
|
||||
- 'tests/integration/sqlstore/**'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12", "3.13"]
|
||||
fail-fast: false
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
env:
|
||||
POSTGRES_USER: llamastack
|
||||
POSTGRES_PASSWORD: llamastack
|
||||
POSTGRES_DB: llamastack
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install dependencies
|
||||
uses: ./.github/actions/setup-runner
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Run SqlStore Integration Tests
|
||||
env:
|
||||
ENABLE_POSTGRES_TESTS: "true"
|
||||
POSTGRES_HOST: localhost
|
||||
POSTGRES_PORT: 5432
|
||||
POSTGRES_DB: llamastack
|
||||
POSTGRES_USER: llamastack
|
||||
POSTGRES_PASSWORD: llamastack
|
||||
run: |
|
||||
uv run pytest -sv tests/integration/providers/utils/sqlstore/
|
||||
|
||||
- name: Upload test logs
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
|
||||
path: |
|
||||
*.log
|
||||
retention-days: 1
|
|
@ -56,8 +56,8 @@ shields: []
|
|||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: "oauth2_token"
|
||||
config:
|
||||
provider_config:
|
||||
type: "oauth2_token"
|
||||
jwks:
|
||||
uri: "https://my-token-issuing-svc.com/jwks"
|
||||
```
|
||||
|
@ -226,6 +226,8 @@ server:
|
|||
|
||||
### Authentication Configuration
|
||||
|
||||
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
||||
|
||||
The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:
|
||||
|
||||
```
|
||||
|
@ -240,8 +242,8 @@ The server can be configured to use service account tokens for authorization, va
|
|||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "oauth2_token"
|
||||
config:
|
||||
provider_config:
|
||||
type: "oauth2_token"
|
||||
jwks:
|
||||
uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
|
||||
token: "${env.TOKEN:+}"
|
||||
|
@ -325,13 +327,25 @@ You can easily validate a request by running:
|
|||
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
|
||||
```
|
||||
|
||||
#### GitHub Token Provider
|
||||
Validates GitHub personal access tokens or OAuth tokens directly:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_config:
|
||||
type: "github_token"
|
||||
github_api_base_url: "https://api.github.com" # Or GitHub Enterprise URL
|
||||
```
|
||||
|
||||
The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
|
||||
|
||||
#### Custom Provider
|
||||
Validates tokens against a custom authentication endpoint:
|
||||
```yaml
|
||||
server:
|
||||
auth:
|
||||
provider_type: "custom"
|
||||
config:
|
||||
provider_config:
|
||||
type: "custom"
|
||||
endpoint: "https://auth.example.com/validate" # URL of the auth endpoint
|
||||
```
|
||||
|
||||
|
@ -416,8 +430,8 @@ clients.
|
|||
server:
|
||||
port: 8321
|
||||
auth:
|
||||
provider_type: custom
|
||||
config:
|
||||
provider_config:
|
||||
type: custom
|
||||
endpoint: https://auth.example.com/validate
|
||||
quota:
|
||||
kvstore:
|
||||
|
|
|
@ -39,6 +39,13 @@ docker pull llama-stack/distribution-meta-reference-gpu
|
|||
|
||||
**Guides:** [Meta Reference GPU Guide](self_hosted_distro/meta-reference-gpu)
|
||||
|
||||
### 🖥️ Self-Hosted with NVIDA NeMo Microservices
|
||||
|
||||
**Use `nvidia` if you:**
|
||||
- Want to use Llama Stack with NVIDIA NeMo Microservices
|
||||
|
||||
**Guides:** [NVIDIA Distribution Guide](self_hosted_distro/nvidia)
|
||||
|
||||
### ☁️ Managed Hosting
|
||||
|
||||
**Use remote-hosted endpoints if you:**
|
||||
|
|
177
docs/source/distributions/self_hosted_distro/nvidia.md
Normal file
177
docs/source/distributions/self_hosted_distro/nvidia.md
Normal file
|
@ -0,0 +1,177 @@
|
|||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# NVIDIA Distribution
|
||||
|
||||
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
|
||||
|
||||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `inline::localfs`, `remote::nvidia` |
|
||||
| eval | `remote::nvidia` |
|
||||
| inference | `remote::nvidia` |
|
||||
| post_training | `remote::nvidia` |
|
||||
| safety | `remote::nvidia` |
|
||||
| scoring | `inline::basic` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `inline::rag-runtime` |
|
||||
| vector_io | `inline::faiss` |
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
||||
- `NVIDIA_GUARDRAILS_CONFIG_ID`: NVIDIA Guardrail Configuration ID (default: `self-check`)
|
||||
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
|
||||
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
||||
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
||||
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
|
||||
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||
- `nvidia/nv-embedqa-e5-v5 `
|
||||
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||
- `snowflake/arctic-embed-l `
|
||||
|
||||
|
||||
## Prerequisites
|
||||
### NVIDIA API Keys
|
||||
|
||||
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||
|
||||
### Deploy NeMo Microservices Platform
|
||||
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||
|
||||
## Supported Services
|
||||
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||
|
||||
### Inference: NVIDIA NIM
|
||||
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||
|
||||
The deployed platform includes the NIM Proxy microservice, which is the service that provides to access your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
|
||||
|
||||
### Datasetio API: NeMo Data Store
|
||||
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Eval API: NeMo Evaluator
|
||||
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Post-Training API: NeMo Customizer
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Safety API: NeMo Guardrails
|
||||
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
## Deploying models
|
||||
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||
|
||||
Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||
```sh
|
||||
# URL to NeMo NIM Proxy service
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||
-H 'accept: application/json' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "llama-3.2-1b-instruct",
|
||||
"namespace": "meta",
|
||||
"config": {
|
||||
"model": "meta/llama-3.2-1b-instruct",
|
||||
"nim_deployment": {
|
||||
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||
"image_tag": "1.8.3",
|
||||
"pvc_size": "25Gi",
|
||||
"gpu": 1,
|
||||
"additional_envs": {
|
||||
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||
|
||||
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||
```sh
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||
```
|
||||
|
||||
## Running Llama Stack with NVIDIA
|
||||
|
||||
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-nvidia \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
|
@ -12,6 +12,7 @@ Please refer to the remote provider documentation.
|
|||
|-------|------|----------|---------|-------------|
|
||||
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
|
||||
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
@ -403,15 +403,16 @@ def _run_stack_build_command_from_build_config(
|
|||
if template_name:
|
||||
# copy run.yaml from template to build_dir instead of generating it again
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
|
||||
with importlib.resources.as_file(template_path) as path:
|
||||
run_config_file = build_dir / f"{template_name}-run.yaml"
|
||||
|
||||
with importlib.resources.as_file(template_path) as path:
|
||||
shutil.copy(path, run_config_file)
|
||||
|
||||
cprint("Build Successful!", color="green", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built template here: {template_path}", color="blue", file=sys.stderr)
|
||||
cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr)
|
||||
cprint(
|
||||
"You can run the new Llama Stack distro via: "
|
||||
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "blue"),
|
||||
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
|
||||
color="green",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
|
@ -155,7 +155,11 @@ class StackRun(Subcommand):
|
|||
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
|
||||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config" and template_name:
|
||||
if arg == "config":
|
||||
if template_name:
|
||||
server_args.template = str(template_name)
|
||||
else:
|
||||
# Set the config file path
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
@ -168,6 +172,9 @@ class StackRun(Subcommand):
|
|||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
if template_name:
|
||||
run_args.extend(["--template", str(template_name)])
|
||||
else:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -161,23 +161,113 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
|
||||
|
||||
class AuthProviderType(StrEnum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth2 token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
|
||||
audience: str = Field(default="llama-stack")
|
||||
verify_tls: bool = Field(default=True)
|
||||
tls_cafile: Path | None = Field(default=None)
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
|
||||
introspection: OAuth2IntrospectionConfig | None = Field(
|
||||
default=None, description="OAuth2 introspection configuration"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class CustomAuthConfig(BaseModel):
|
||||
"""Configuration for custom authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
|
||||
endpoint: str = Field(
|
||||
...,
|
||||
description="Custom authentication endpoint URL",
|
||||
)
|
||||
|
||||
|
||||
class GitHubTokenAuthConfig(BaseModel):
|
||||
"""Configuration for GitHub token authentication."""
|
||||
|
||||
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
|
||||
github_api_base_url: str = Field(
|
||||
default="https://api.github.com",
|
||||
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
|
||||
)
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"login": "roles",
|
||||
"organizations": "teams",
|
||||
},
|
||||
description="Mapping from GitHub user fields to access attributes",
|
||||
)
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
provider_type: AuthProviderType = Field(
|
||||
"""Top-level authentication configuration."""
|
||||
|
||||
provider_config: AuthProviderConfig = Field(
|
||||
...,
|
||||
description="Type of authentication provider",
|
||||
description="Authentication provider configuration",
|
||||
)
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
access_policy: list[AccessRule] = Field(
|
||||
default=[],
|
||||
description="Rules for determining access to resources",
|
||||
)
|
||||
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
|
|
|
@ -87,8 +87,12 @@ class AuthenticationMiddleware:
|
|||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
if not auth_header:
|
||||
error_msg = self.auth_provider.get_auth_error_message(scope)
|
||||
return await self._send_auth_error(send, error_msg)
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Invalid Authorization header format")
|
||||
|
||||
token = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
|
|
|
@ -8,15 +8,19 @@ import ssl
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
|
|||
|
||||
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: dict[str, list[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
|
@ -62,6 +64,10 @@ class AuthProvider(ABC):
|
|||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return provider-specific authentication error message."""
|
||||
return "Authentication required"
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
|
@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
return attributes
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
class OAuth2IntrospectionConfig(BaseModel):
|
||||
url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
send_secret_in_body: bool = False
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
audience: str = "llama-stack"
|
||||
verify_tls: bool = True
|
||||
tls_cafile: Path | None = None
|
||||
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
jwks: OAuth2JWKSConfig | None
|
||||
introspection: OAuth2IntrospectionConfig | None = None
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode(self) -> Self:
|
||||
if not self.jwks and not self.introspection:
|
||||
raise ValueError("One of jwks or introspection must be configured")
|
||||
if self.jwks and self.introspection:
|
||||
raise ValueError("At present only one of jwks or introspection should be configured")
|
||||
return self
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||
|
@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
This should be the standard authentication provider for most use cases.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
||||
def __init__(self, config: OAuth2TokenAuthConfig):
|
||||
self.config = config
|
||||
self._jwks_at: float = 0.0
|
||||
self._jwks: dict[str, str] = {}
|
||||
|
@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
issuer=self.config.issuer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid JWT token: {token}") from exc
|
||||
raise ValueError("Invalid JWT token") from exc
|
||||
|
||||
# There are other standard claims, the most relevant of which is `scope`.
|
||||
# We should incorporate these into the access attributes.
|
||||
|
@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
async def close(self):
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return OAuth2-specific authentication error message."""
|
||||
if self.config.issuer:
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
|
||||
elif self.config.introspection:
|
||||
# Extract domain from introspection URL for a cleaner message
|
||||
domain = urlparse(self.config.introspection.url).netloc
|
||||
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
|
||||
else:
|
||||
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
"""
|
||||
Refresh the JWKS cache.
|
||||
|
@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProviderConfig(BaseModel):
|
||||
endpoint: str
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: CustomAuthProviderConfig):
|
||||
def __init__(self, config: CustomAuthConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
|
@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
return User(auth_response.principal, auth_response.attributes)
|
||||
return User(principal=auth_response.principal, attributes=auth_response.attributes)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider):
|
|||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return custom auth provider-specific authentication error message."""
|
||||
domain = urlparse(self.config.endpoint).netloc
|
||||
if domain:
|
||||
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||
else:
|
||||
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
|
||||
|
||||
|
||||
class GitHubTokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
GitHub token authentication provider that validates GitHub access tokens directly.
|
||||
|
||||
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
|
||||
them against the GitHub API to get user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GitHubTokenAuthConfig):
|
||||
self.config = config
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a GitHub token by calling the GitHub API.
|
||||
|
||||
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
|
||||
"""
|
||||
try:
|
||||
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"GitHub token validation failed: {e}")
|
||||
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
|
||||
|
||||
principal = user_info["user"]["login"]
|
||||
|
||||
github_data = {
|
||||
"login": user_info["user"]["login"],
|
||||
"id": str(user_info["user"]["id"]),
|
||||
"organizations": user_info.get("organizations", []),
|
||||
}
|
||||
|
||||
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=principal,
|
||||
attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
def get_auth_error_message(self, scope: dict | None = None) -> str:
|
||||
"""Return GitHub-specific authentication error message."""
|
||||
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
|
||||
|
||||
|
||||
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
|
||||
"""Fetch user info and organizations from GitHub API."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": "llama-stack",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
|
||||
user_response.raise_for_status()
|
||||
user_data = user_response.json()
|
||||
|
||||
return {
|
||||
"user": user_data,
|
||||
}
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
provider_config = config.provider_config
|
||||
|
||||
if provider_type == "custom":
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
if isinstance(provider_config, CustomAuthConfig):
|
||||
return CustomAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, OAuth2TokenAuthConfig):
|
||||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
|
@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError
|
|||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
|
@ -217,7 +221,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal, user_attributes)
|
||||
user = User(principal=principal, attributes=user_attributes)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
|
@ -405,13 +409,13 @@ def main(args: argparse.Namespace | None = None):
|
|||
args = parser.parse_args()
|
||||
|
||||
log_line = ""
|
||||
if args.config:
|
||||
if hasattr(args, "config") and args.config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif args.template:
|
||||
elif hasattr(args, "template") and args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
|
@ -455,7 +459,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
else:
|
||||
if config.server.quota:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
|
@ -19,6 +19,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
|
|
|
@ -154,10 +154,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.mdel_validate_json(vector_db_data)
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=await MilvusIndex(
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
|
@ -259,6 +259,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.delete(key)
|
||||
if store_id in self.openai_vector_stores:
|
||||
del self.openai_vector_stores[store_id]
|
||||
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from persistent storage."""
|
||||
|
@ -377,6 +379,29 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
|
||||
return {}
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
return
|
||||
|
||||
file_data = [
|
||||
{
|
||||
"store_file_id": f"{store_id}_{file_id}",
|
||||
"store_id": store_id,
|
||||
"file_id": file_id,
|
||||
"file_info": json.dumps(file_info),
|
||||
}
|
||||
]
|
||||
await asyncio.to_thread(
|
||||
self.client.upsert,
|
||||
collection_name="openai_vector_store_files",
|
||||
data=file_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from Milvus database."""
|
||||
try:
|
||||
|
@ -405,29 +430,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
|
||||
return []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
return
|
||||
|
||||
file_data = [
|
||||
{
|
||||
"store_file_id": f"{store_id}_{file_id}",
|
||||
"store_id": store_id,
|
||||
"file_id": file_id,
|
||||
"file_info": json.dumps(file_info),
|
||||
}
|
||||
]
|
||||
await asyncio.to_thread(
|
||||
self.client.upsert,
|
||||
collection_name="openai_vector_store_files",
|
||||
data=file_data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from Milvus database."""
|
||||
try:
|
||||
|
|
|
@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
|
|||
from llama_stack.log import get_logger
|
||||
|
||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||
from .sqlstore import SqlStoreType
|
||||
|
||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||
|
||||
|
@ -71,9 +72,18 @@ class AuthorizedSqlStore:
|
|||
:param sql_store: Base SqlStore implementation to wrap
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
def _detect_database_type(self) -> None:
|
||||
"""Detect the database type from the underlying SQL store."""
|
||||
if not hasattr(self.sql_store, "config"):
|
||||
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||
|
||||
self.database_type = self.sql_store.config.type
|
||||
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _validate_sql_optimized_policy(self) -> None:
|
||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||
|
||||
|
@ -181,6 +191,50 @@ class AuthorizedSqlStore:
|
|||
else:
|
||||
return self._build_conservative_where_clause()
|
||||
|
||||
def _json_extract(self, column: str, path: str) -> str:
|
||||
"""Extract JSON value (keeping JSON type).
|
||||
|
||||
Args:
|
||||
column: The JSON column name
|
||||
path: The JSON path (e.g., 'roles', 'teams')
|
||||
|
||||
Returns:
|
||||
SQL expression to extract JSON value
|
||||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
return f"{column}->'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _json_extract_text(self, column: str, path: str) -> str:
|
||||
"""Extract JSON value as text.
|
||||
|
||||
Args:
|
||||
column: The JSON column name
|
||||
path: The JSON path (e.g., 'roles', 'teams')
|
||||
|
||||
Returns:
|
||||
SQL expression to extract JSON value as text
|
||||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
return f"{column}->>'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _get_public_access_conditions(self) -> list[str]:
|
||||
"""Get the SQL conditions for public access."""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
# Postgres stores JSON null as 'null'
|
||||
return ["access_attributes::text = 'null'"]
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return ["access_attributes = 'null'"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _build_default_policy_where_clause(self) -> str:
|
||||
"""Build SQL WHERE clause for the default policy.
|
||||
|
||||
|
@ -189,29 +243,32 @@ class AuthorizedSqlStore:
|
|||
"""
|
||||
current_user = get_authenticated_user()
|
||||
|
||||
base_conditions = self._get_public_access_conditions()
|
||||
if not current_user or not current_user.attributes:
|
||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
||||
# Only allow public records
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
else:
|
||||
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
|
||||
|
||||
user_attr_conditions = []
|
||||
|
||||
for attr_key, user_values in current_user.attributes.items():
|
||||
if user_values:
|
||||
value_conditions = []
|
||||
for value in user_values:
|
||||
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
|
||||
# Check if JSON array contains the value
|
||||
escaped_value = value.replace("'", "''")
|
||||
json_text = self._json_extract_text("access_attributes", attr_key)
|
||||
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
|
||||
|
||||
if value_conditions:
|
||||
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
|
||||
# Check if the category is missing (NULL)
|
||||
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
|
||||
user_matches_category = f"({' OR '.join(value_conditions)})"
|
||||
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
||||
|
||||
if user_attr_conditions:
|
||||
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
||||
base_conditions.append(all_requirements_met)
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
else:
|
||||
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
|
||||
def _build_conservative_where_clause(self) -> str:
|
||||
|
@ -222,5 +279,8 @@ class AuthorizedSqlStore:
|
|||
current_user = get_authenticated_user()
|
||||
|
||||
if not current_user:
|
||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
||||
# Only allow public records
|
||||
base_conditions = self._get_public_access_conditions()
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
|
||||
return "1=1"
|
||||
|
|
|
@ -4,9 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
|
@ -19,7 +18,7 @@ from .api import SqlStore
|
|||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
|
||||
class SqlStoreType(Enum):
|
||||
class SqlStoreType(StrEnum):
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
|
||||
|
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
|
|||
|
||||
|
||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||
|
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
|||
|
||||
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
|
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
|||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
||||
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
impl = SqlAlchemySqlStoreImpl(config)
|
||||
|
|
7
llama_stack/templates/nvidia/__init__.py
Normal file
7
llama_stack/templates/nvidia/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .nvidia import get_distribution_template # noqa: F401
|
29
llama_stack/templates/nvidia/build.yaml
Normal file
29
llama_stack/templates/nvidia/build.yaml
Normal file
|
@ -0,0 +1,29 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
|
||||
providers:
|
||||
inference:
|
||||
- remote::nvidia
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
safety:
|
||||
- remote::nvidia
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- remote::nvidia
|
||||
post_training:
|
||||
- remote::nvidia
|
||||
datasetio:
|
||||
- inline::localfs
|
||||
- remote::nvidia
|
||||
scoring:
|
||||
- inline::basic
|
||||
tool_runtime:
|
||||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
149
llama_stack/templates/nvidia/doc_template.md
Normal file
149
llama_stack/templates/nvidia/doc_template.md
Normal file
|
@ -0,0 +1,149 @@
|
|||
# NVIDIA Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
|
||||
## Prerequisites
|
||||
### NVIDIA API Keys
|
||||
|
||||
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||
|
||||
### Deploy NeMo Microservices Platform
|
||||
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||
|
||||
## Supported Services
|
||||
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||
|
||||
### Inference: NVIDIA NIM
|
||||
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||
|
||||
The deployed platform includes the NIM Proxy microservice, which is the service that provides to access your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
|
||||
|
||||
### Datasetio API: NeMo Data Store
|
||||
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Eval API: NeMo Evaluator
|
||||
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Post-Training API: NeMo Customizer
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
### Safety API: NeMo Guardrails
|
||||
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||
|
||||
## Deploying models
|
||||
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||
|
||||
Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||
```sh
|
||||
# URL to NeMo NIM Proxy service
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||
-H 'accept: application/json' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "llama-3.2-1b-instruct",
|
||||
"namespace": "meta",
|
||||
"config": {
|
||||
"model": "meta/llama-3.2-1b-instruct",
|
||||
"nim_deployment": {
|
||||
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||
"image_tag": "1.8.3",
|
||||
"pvc_size": "25Gi",
|
||||
"gpu": 1,
|
||||
"additional_envs": {
|
||||
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||
|
||||
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||
```sh
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||
```
|
||||
|
||||
## Running Llama Stack with NVIDIA
|
||||
|
||||
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
## Example Notebooks
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
150
llama_stack/templates/nvidia/nvidia.py
Normal file
150
llama_stack/templates/nvidia/nvidia.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::nvidia"],
|
||||
"vector_io": ["inline::faiss"],
|
||||
"safety": ["remote::nvidia"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["remote::nvidia"],
|
||||
"post_training": ["remote::nvidia"],
|
||||
"datasetio": ["inline::localfs", "remote::nvidia"],
|
||||
"scoring": ["inline::basic"],
|
||||
"tool_runtime": ["inline::rag-runtime"],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIAConfig.sample_run_config(),
|
||||
)
|
||||
safety_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIASafetyConfig.sample_run_config(),
|
||||
)
|
||||
datasetio_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NvidiaDatasetIOConfig.sample_run_config(),
|
||||
)
|
||||
eval_provider = Provider(
|
||||
provider_id="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
config=NVIDIAEvalConfig.sample_run_config(),
|
||||
)
|
||||
inference_model = ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="nvidia",
|
||||
)
|
||||
safety_model = ModelInput(
|
||||
model_id="${env.SAFETY_MODEL}",
|
||||
provider_id="nvidia",
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"nvidia": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name="nvidia",
|
||||
distro_type="self_hosted",
|
||||
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"datasetio": [datasetio_provider],
|
||||
"eval": [eval_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
"run-with-safety.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [
|
||||
inference_provider,
|
||||
safety_provider,
|
||||
],
|
||||
"eval": [eval_provider],
|
||||
},
|
||||
default_models=[inference_model, safety_model],
|
||||
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"NVIDIA_API_KEY": (
|
||||
"",
|
||||
"NVIDIA API Key",
|
||||
),
|
||||
"NVIDIA_APPEND_API_VERSION": (
|
||||
"True",
|
||||
"Whether to append the API version to the base_url",
|
||||
),
|
||||
## Nemo Customizer related variables
|
||||
"NVIDIA_DATASET_NAMESPACE": (
|
||||
"default",
|
||||
"NVIDIA Dataset Namespace",
|
||||
),
|
||||
"NVIDIA_PROJECT_ID": (
|
||||
"test-project",
|
||||
"NVIDIA Project ID",
|
||||
),
|
||||
"NVIDIA_CUSTOMIZER_URL": (
|
||||
"https://customizer.api.nvidia.com",
|
||||
"NVIDIA Customizer URL",
|
||||
),
|
||||
"NVIDIA_OUTPUT_MODEL_DIR": (
|
||||
"test-example-model@v1",
|
||||
"NVIDIA Output Model Directory",
|
||||
),
|
||||
"GUARDRAILS_SERVICE_URL": (
|
||||
"http://0.0.0.0:7331",
|
||||
"URL for the NeMo Guardrails Service",
|
||||
),
|
||||
"NVIDIA_GUARDRAILS_CONFIG_ID": (
|
||||
"self-check",
|
||||
"NVIDIA Guardrail Configuration ID",
|
||||
),
|
||||
"NVIDIA_EVALUATOR_URL": (
|
||||
"http://0.0.0.0:7331",
|
||||
"URL for the NeMo Evaluator Service",
|
||||
),
|
||||
"INFERENCE_MODEL": (
|
||||
"Llama3.1-8B-Instruct",
|
||||
"Inference model",
|
||||
),
|
||||
"SAFETY_MODEL": (
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
"Name of the model to use for safety",
|
||||
),
|
||||
},
|
||||
)
|
119
llama_stack/templates/nvidia/run-with-safety.yaml
Normal file
119
llama_stack/templates/nvidia/run-with-safety.yaml
Normal file
|
@ -0,0 +1,119 @@
|
|||
version: 2
|
||||
image_name: nvidia
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
|
||||
datasetio:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/localfs_datasetio.db
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: nvidia
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.SAFETY_MODEL}
|
||||
provider_id: nvidia
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL}
|
||||
provider_id: nvidia
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
226
llama_stack/templates/nvidia/run.yaml
Normal file
226
llama_stack/templates/nvidia/run.yaml
Normal file
|
@ -0,0 +1,226 @@
|
|||
version: 2
|
||||
image_name: nvidia
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
|
||||
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
|
||||
datasetio:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
|
||||
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-405b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-1b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-3b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-e5-v5
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 4096
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: snowflake/arctic-embed-l
|
||||
provider_id: nvidia
|
||||
provider_model_id: snowflake/arctic-embed-l
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
6
llama_stack/ui/app/api/auth/[...nextauth]/route.ts
Normal file
6
llama_stack/ui/app/api/auth/[...nextauth]/route.ts
Normal file
|
@ -0,0 +1,6 @@
|
|||
import NextAuth from "next-auth";
|
||||
import { authOptions } from "@/lib/auth";
|
||||
|
||||
const handler = NextAuth(authOptions);
|
||||
|
||||
export { handler as GET, handler as POST };
|
118
llama_stack/ui/app/auth/signin/page.tsx
Normal file
118
llama_stack/ui/app/auth/signin/page.tsx
Normal file
|
@ -0,0 +1,118 @@
|
|||
"use client";
|
||||
|
||||
import { signIn, signOut, useSession } from "next-auth/react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Copy, Check, Home, Github } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export default function SignInPage() {
|
||||
const { data: session, status } = useSession();
|
||||
const [copied, setCopied] = useState(false);
|
||||
const router = useRouter();
|
||||
|
||||
const handleCopyToken = async () => {
|
||||
if (session?.accessToken) {
|
||||
await navigator.clipboard.writeText(session.accessToken);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
if (status === "loading") {
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen">
|
||||
<div className="text-muted-foreground">Loading...</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen">
|
||||
<Card className="w-[400px]">
|
||||
<CardHeader>
|
||||
<CardTitle>Authentication</CardTitle>
|
||||
<CardDescription>
|
||||
{session
|
||||
? "You are successfully authenticated!"
|
||||
: "Sign in with GitHub to use your access token as an API key"}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{!session ? (
|
||||
<Button
|
||||
onClick={() => {
|
||||
console.log("Signing in with GitHub...");
|
||||
signIn("github", { callbackUrl: "/auth/signin" }).catch(
|
||||
(error) => {
|
||||
console.error("Sign in error:", error);
|
||||
},
|
||||
);
|
||||
}}
|
||||
className="w-full"
|
||||
variant="default"
|
||||
>
|
||||
<Github className="mr-2 h-4 w-4" />
|
||||
Sign in with GitHub
|
||||
</Button>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
Signed in as {session.user?.email}
|
||||
</div>
|
||||
|
||||
{session.accessToken && (
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm font-medium">
|
||||
GitHub Access Token:
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<code className="flex-1 p-2 bg-muted rounded text-xs break-all">
|
||||
{session.accessToken}
|
||||
</code>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={handleCopyToken}
|
||||
>
|
||||
{copied ? (
|
||||
<Check className="h-4 w-4" />
|
||||
) : (
|
||||
<Copy className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
This GitHub token will be used as your API key for
|
||||
authenticated Llama Stack requests.
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button onClick={() => router.push("/")} className="flex-1">
|
||||
<Home className="mr-2 h-4 w-4" />
|
||||
Go to Dashboard
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => signOut()}
|
||||
variant="outline"
|
||||
className="flex-1"
|
||||
>
|
||||
Sign out
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
import type { Metadata } from "next";
|
||||
import { ThemeProvider } from "@/components/ui/theme-provider";
|
||||
import { SessionProvider } from "@/components/providers/session-provider";
|
||||
import { Geist, Geist_Mono } from "next/font/google";
|
||||
import { ModeToggle } from "@/components/ui/mode-toggle";
|
||||
import "./globals.css";
|
||||
|
@ -21,11 +22,13 @@ export const metadata: Metadata = {
|
|||
|
||||
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||
import { AppSidebar } from "@/components/layout/app-sidebar";
|
||||
import { SignInButton } from "@/components/ui/sign-in-button";
|
||||
|
||||
export default function Layout({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<html lang="en" suppressHydrationWarning>
|
||||
<body className={`${geistSans.variable} ${geistMono.variable} font-sans`}>
|
||||
<SessionProvider>
|
||||
<ThemeProvider
|
||||
attribute="class"
|
||||
defaultTheme="system"
|
||||
|
@ -41,7 +44,8 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
|||
<SidebarTrigger />
|
||||
</div>
|
||||
<div className="flex-1 text-center"></div>
|
||||
<div className="flex-none">
|
||||
<div className="flex-none flex items-center gap-2">
|
||||
<SignInButton />
|
||||
<ModeToggle />
|
||||
</div>
|
||||
</div>
|
||||
|
@ -49,6 +53,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
|||
</main>
|
||||
</SidebarProvider>
|
||||
</ThemeProvider>
|
||||
</SessionProvider>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
|
|
|
@ -4,11 +4,12 @@ import { useEffect, useState } from "react";
|
|||
import { useParams } from "next/navigation";
|
||||
import { ChatCompletion } from "@/lib/types";
|
||||
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
|
||||
import { client } from "@/lib/client";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
|
||||
export default function ChatCompletionDetailPage() {
|
||||
const params = useParams();
|
||||
const id = params.id as string;
|
||||
const client = useAuthClient();
|
||||
|
||||
const [completionDetail, setCompletionDetail] =
|
||||
useState<ChatCompletion | null>(null);
|
||||
|
@ -45,7 +46,7 @@ export default function ChatCompletionDetailPage() {
|
|||
};
|
||||
|
||||
fetchCompletionDetail();
|
||||
}, [id]);
|
||||
}, [id, client]);
|
||||
|
||||
return (
|
||||
<ChatCompletionDetailView
|
||||
|
|
|
@ -5,11 +5,12 @@ import { useParams } from "next/navigation";
|
|||
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
|
||||
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
||||
import { ResponseDetailView } from "@/components/responses/responses-detail";
|
||||
import { client } from "@/lib/client";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
|
||||
export default function ResponseDetailPage() {
|
||||
const params = useParams();
|
||||
const id = params.id as string;
|
||||
const client = useAuthClient();
|
||||
|
||||
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
|
||||
null,
|
||||
|
@ -109,7 +110,7 @@ export default function ResponseDetailPage() {
|
|||
};
|
||||
|
||||
fetchResponseDetail();
|
||||
}, [id]);
|
||||
}, [id, client]);
|
||||
|
||||
return (
|
||||
<ResponseDetailView
|
||||
|
|
|
@ -12,24 +12,34 @@ jest.mock("next/navigation", () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
// Mock next-auth
|
||||
jest.mock("next-auth/react", () => ({
|
||||
useSession: () => ({
|
||||
status: "authenticated",
|
||||
data: { accessToken: "mock-token" },
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock helper functions
|
||||
jest.mock("@/lib/truncate-text");
|
||||
jest.mock("@/lib/format-message-content");
|
||||
|
||||
// Mock the client
|
||||
jest.mock("@/lib/client", () => ({
|
||||
client: {
|
||||
// Mock the auth client hook
|
||||
const mockClient = {
|
||||
chat: {
|
||||
completions: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => mockClient,
|
||||
}));
|
||||
|
||||
// Mock the usePagination hook
|
||||
const mockLoadMore = jest.fn();
|
||||
jest.mock("@/hooks/usePagination", () => ({
|
||||
jest.mock("@/hooks/use-pagination", () => ({
|
||||
usePagination: jest.fn(() => ({
|
||||
data: [],
|
||||
status: "idle",
|
||||
|
@ -47,7 +57,7 @@ import {
|
|||
} from "@/lib/format-message-content";
|
||||
|
||||
// Import the mocked hook
|
||||
import { usePagination } from "@/hooks/usePagination";
|
||||
import { usePagination } from "@/hooks/use-pagination";
|
||||
const mockedUsePagination = usePagination as jest.MockedFunction<
|
||||
typeof usePagination
|
||||
>;
|
||||
|
|
|
@ -10,8 +10,7 @@ import {
|
|||
extractTextFromContentPart,
|
||||
extractDisplayableText,
|
||||
} from "@/lib/format-message-content";
|
||||
import { usePagination } from "@/hooks/usePagination";
|
||||
import { client } from "@/lib/client";
|
||||
import { usePagination } from "@/hooks/use-pagination";
|
||||
|
||||
interface ChatCompletionsTableProps {
|
||||
/** Optional pagination configuration */
|
||||
|
@ -32,12 +31,15 @@ function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
|
|||
export function ChatCompletionsTable({
|
||||
paginationOptions,
|
||||
}: ChatCompletionsTableProps) {
|
||||
const fetchFunction = async (params: {
|
||||
const fetchFunction = async (
|
||||
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
|
||||
params: {
|
||||
after?: string;
|
||||
limit: number;
|
||||
model?: string;
|
||||
order?: string;
|
||||
}) => {
|
||||
},
|
||||
) => {
|
||||
const response = await client.chat.completions.list({
|
||||
after: params.after,
|
||||
limit: params.limit,
|
||||
|
|
|
@ -12,7 +12,7 @@ jest.mock("next/navigation", () => ({
|
|||
}));
|
||||
|
||||
// Mock the useInfiniteScroll hook
|
||||
jest.mock("@/hooks/useInfiniteScroll", () => ({
|
||||
jest.mock("@/hooks/use-infinite-scroll", () => ({
|
||||
useInfiniteScroll: jest.fn((onLoadMore, options) => {
|
||||
const ref = React.useRef(null);
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import { useRouter } from "next/navigation";
|
|||
import { useRef } from "react";
|
||||
import { truncateText } from "@/lib/truncate-text";
|
||||
import { PaginationStatus } from "@/lib/types";
|
||||
import { useInfiniteScroll } from "@/hooks/useInfiniteScroll";
|
||||
import { useInfiniteScroll } from "@/hooks/use-infinite-scroll";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
|
|
7
llama_stack/ui/components/providers/session-provider.tsx
Normal file
7
llama_stack/ui/components/providers/session-provider.tsx
Normal file
|
@ -0,0 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { SessionProvider as NextAuthSessionProvider } from "next-auth/react";
|
||||
|
||||
export function SessionProvider({ children }: { children: React.ReactNode }) {
|
||||
return <NextAuthSessionProvider>{children}</NextAuthSessionProvider>;
|
||||
}
|
|
@ -12,21 +12,31 @@ jest.mock("next/navigation", () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
// Mock next-auth
|
||||
jest.mock("next-auth/react", () => ({
|
||||
useSession: () => ({
|
||||
status: "authenticated",
|
||||
data: { accessToken: "mock-token" },
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock helper functions
|
||||
jest.mock("@/lib/truncate-text");
|
||||
|
||||
// Mock the client
|
||||
jest.mock("@/lib/client", () => ({
|
||||
client: {
|
||||
// Mock the auth client hook
|
||||
const mockClient = {
|
||||
responses: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => mockClient,
|
||||
}));
|
||||
|
||||
// Mock the usePagination hook
|
||||
const mockLoadMore = jest.fn();
|
||||
jest.mock("@/hooks/usePagination", () => ({
|
||||
jest.mock("@/hooks/use-pagination", () => ({
|
||||
usePagination: jest.fn(() => ({
|
||||
data: [],
|
||||
status: "idle",
|
||||
|
@ -40,7 +50,7 @@ jest.mock("@/hooks/usePagination", () => ({
|
|||
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
||||
|
||||
// Import the mocked hook
|
||||
import { usePagination } from "@/hooks/usePagination";
|
||||
import { usePagination } from "@/hooks/use-pagination";
|
||||
const mockedUsePagination = usePagination as jest.MockedFunction<
|
||||
typeof usePagination
|
||||
>;
|
||||
|
|
|
@ -6,8 +6,7 @@ import {
|
|||
UsePaginationOptions,
|
||||
} from "@/lib/types";
|
||||
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||
import { usePagination } from "@/hooks/usePagination";
|
||||
import { client } from "@/lib/client";
|
||||
import { usePagination } from "@/hooks/use-pagination";
|
||||
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
|
||||
import {
|
||||
isMessageInput,
|
||||
|
@ -125,12 +124,15 @@ function formatResponseToRow(response: OpenAIResponse): LogTableRow {
|
|||
}
|
||||
|
||||
export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
|
||||
const fetchFunction = async (params: {
|
||||
const fetchFunction = async (
|
||||
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
|
||||
params: {
|
||||
after?: string;
|
||||
limit: number;
|
||||
model?: string;
|
||||
order?: string;
|
||||
}) => {
|
||||
},
|
||||
) => {
|
||||
const response = await client.responses.list({
|
||||
after: params.after,
|
||||
limit: params.limit,
|
||||
|
|
25
llama_stack/ui/components/ui/sign-in-button.tsx
Normal file
25
llama_stack/ui/components/ui/sign-in-button.tsx
Normal file
|
@ -0,0 +1,25 @@
|
|||
"use client";
|
||||
|
||||
import { User } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Button } from "./button";
|
||||
|
||||
export function SignInButton() {
|
||||
const { data: session, status } = useSession();
|
||||
|
||||
return (
|
||||
<Button variant="ghost" size="sm" asChild>
|
||||
<Link href="/auth/signin" className="flex items-center">
|
||||
<User className="mr-2 h-4 w-4" />
|
||||
<span>
|
||||
{status === "loading"
|
||||
? "Loading..."
|
||||
: session
|
||||
? session.user?.email || "Signed In"
|
||||
: "Sign In"}
|
||||
</span>
|
||||
</Link>
|
||||
</Button>
|
||||
);
|
||||
}
|
24
llama_stack/ui/hooks/use-auth-client.ts
Normal file
24
llama_stack/ui/hooks/use-auth-client.ts
Normal file
|
@ -0,0 +1,24 @@
|
|||
import { useSession } from "next-auth/react";
|
||||
import { useMemo } from "react";
|
||||
import LlamaStackClient from "llama-stack-client";
|
||||
|
||||
export function useAuthClient() {
|
||||
const { data: session } = useSession();
|
||||
|
||||
const client = useMemo(() => {
|
||||
const clientHostname =
|
||||
typeof window !== "undefined" ? window.location.origin : "";
|
||||
|
||||
const options: any = {
|
||||
baseURL: `${clientHostname}/api`,
|
||||
};
|
||||
|
||||
if (session?.accessToken) {
|
||||
options.apiKey = session.accessToken;
|
||||
}
|
||||
|
||||
return new LlamaStackClient(options);
|
||||
}, [session?.accessToken]);
|
||||
|
||||
return client;
|
||||
}
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
import { useState, useCallback, useEffect, useRef } from "react";
|
||||
import { PaginationStatus, UsePaginationOptions } from "@/lib/types";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface PaginationState<T> {
|
||||
data: T[];
|
||||
|
@ -28,13 +31,18 @@ export interface PaginationReturn<T> {
|
|||
}
|
||||
|
||||
interface UsePaginationParams<T> extends UsePaginationOptions {
|
||||
fetchFunction: (params: {
|
||||
fetchFunction: (
|
||||
client: ReturnType<typeof useAuthClient>,
|
||||
params: {
|
||||
after?: string;
|
||||
limit: number;
|
||||
model?: string;
|
||||
order?: string;
|
||||
}) => Promise<PaginationResponse<T>>;
|
||||
},
|
||||
) => Promise<PaginationResponse<T>>;
|
||||
errorMessagePrefix: string;
|
||||
enabled?: boolean;
|
||||
useAuth?: boolean;
|
||||
}
|
||||
|
||||
export function usePagination<T>({
|
||||
|
@ -43,7 +51,12 @@ export function usePagination<T>({
|
|||
order = "desc",
|
||||
fetchFunction,
|
||||
errorMessagePrefix,
|
||||
enabled = true,
|
||||
useAuth = true,
|
||||
}: UsePaginationParams<T>): PaginationReturn<T> {
|
||||
const { status: sessionStatus } = useSession();
|
||||
const client = useAuthClient();
|
||||
const router = useRouter();
|
||||
const [state, setState] = useState<PaginationState<T>>({
|
||||
data: [],
|
||||
status: "loading",
|
||||
|
@ -74,7 +87,7 @@ export function usePagination<T>({
|
|||
error: null,
|
||||
}));
|
||||
|
||||
const response = await fetchFunction({
|
||||
const response = await fetchFunction(client, {
|
||||
after: after || undefined,
|
||||
limit: fetchLimit,
|
||||
...(model && { model }),
|
||||
|
@ -91,6 +104,17 @@ export function usePagination<T>({
|
|||
status: "idle",
|
||||
}));
|
||||
} catch (err) {
|
||||
// Check if it's a 401 unauthorized error
|
||||
if (
|
||||
err &&
|
||||
typeof err === "object" &&
|
||||
"status" in err &&
|
||||
err.status === 401
|
||||
) {
|
||||
router.push("/auth/signin");
|
||||
return;
|
||||
}
|
||||
|
||||
const errorMessage = isInitialLoad
|
||||
? `Failed to load ${errorMessagePrefix}. Please try refreshing the page.`
|
||||
: `Failed to load more ${errorMessagePrefix}. Please try again.`;
|
||||
|
@ -107,7 +131,7 @@ export function usePagination<T>({
|
|||
}));
|
||||
}
|
||||
},
|
||||
[limit, model, order, fetchFunction, errorMessagePrefix],
|
||||
[limit, model, order, fetchFunction, errorMessagePrefix, client, router],
|
||||
);
|
||||
|
||||
/**
|
||||
|
@ -120,17 +144,28 @@ export function usePagination<T>({
|
|||
}
|
||||
}, [fetchData]);
|
||||
|
||||
// Auto-load initial data on mount
|
||||
// Auto-load initial data on mount when enabled
|
||||
useEffect(() => {
|
||||
if (!hasFetchedInitialData.current) {
|
||||
// If using auth, wait for session to load
|
||||
const isAuthReady = !useAuth || sessionStatus !== "loading";
|
||||
const shouldFetch = enabled && isAuthReady;
|
||||
|
||||
if (shouldFetch && !hasFetchedInitialData.current) {
|
||||
hasFetchedInitialData.current = true;
|
||||
fetchData();
|
||||
} else if (!shouldFetch) {
|
||||
// Reset the flag when disabled so it can fetch when re-enabled
|
||||
hasFetchedInitialData.current = false;
|
||||
}
|
||||
}, [fetchData]);
|
||||
}, [fetchData, enabled, useAuth, sessionStatus]);
|
||||
|
||||
// Override status if we're waiting for auth
|
||||
const effectiveStatus =
|
||||
useAuth && sessionStatus === "loading" ? "loading" : state.status;
|
||||
|
||||
return {
|
||||
data: state.data,
|
||||
status: state.status,
|
||||
status: effectiveStatus,
|
||||
hasMore: state.hasMore,
|
||||
error: state.error,
|
||||
loadMore,
|
11
llama_stack/ui/instrumentation.ts
Normal file
11
llama_stack/ui/instrumentation.ts
Normal file
|
@ -0,0 +1,11 @@
|
|||
/**
|
||||
* Next.js Instrumentation
|
||||
* This file is used for initializing monitoring, tracing, or other observability tools.
|
||||
* It runs once when the server starts, before any application code.
|
||||
*
|
||||
* Learn more: https://nextjs.org/docs/app/building-your-application/optimizing/instrumentation
|
||||
*/
|
||||
|
||||
export async function register() {
|
||||
await import("./lib/config-validator");
|
||||
}
|
38
llama_stack/ui/lib/auth.ts
Normal file
38
llama_stack/ui/lib/auth.ts
Normal file
|
@ -0,0 +1,38 @@
|
|||
import { NextAuthOptions } from "next-auth";
|
||||
import GithubProvider from "next-auth/providers/github";
|
||||
|
||||
export const authOptions: NextAuthOptions = {
|
||||
providers: [
|
||||
GithubProvider({
|
||||
clientId: process.env.GITHUB_CLIENT_ID!,
|
||||
clientSecret: process.env.GITHUB_CLIENT_SECRET!,
|
||||
authorization: {
|
||||
params: {
|
||||
scope: "read:user user:email",
|
||||
},
|
||||
},
|
||||
}),
|
||||
],
|
||||
debug: process.env.NODE_ENV === "development",
|
||||
callbacks: {
|
||||
async jwt({ token, account }) {
|
||||
// Persist the OAuth access_token to the token right after signin
|
||||
if (account) {
|
||||
token.accessToken = account.access_token;
|
||||
}
|
||||
return token;
|
||||
},
|
||||
async session({ session, token }) {
|
||||
// Send properties to the client, like an access_token from a provider.
|
||||
session.accessToken = token.accessToken as string;
|
||||
return session;
|
||||
},
|
||||
},
|
||||
pages: {
|
||||
signIn: "/auth/signin",
|
||||
error: "/auth/signin", // Redirect errors to our custom page
|
||||
},
|
||||
session: {
|
||||
strategy: "jwt",
|
||||
},
|
||||
};
|
|
@ -1,6 +0,0 @@
|
|||
import LlamaStackClient from "llama-stack-client";
|
||||
|
||||
export const client = new LlamaStackClient({
|
||||
baseURL:
|
||||
typeof window !== "undefined" ? `${window.location.origin}/api` : "/api",
|
||||
});
|
56
llama_stack/ui/lib/config-validator.ts
Normal file
56
llama_stack/ui/lib/config-validator.ts
Normal file
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Validates environment configuration for the application
|
||||
* This is called during server initialization
|
||||
*/
|
||||
export function validateServerConfig() {
|
||||
if (process.env.NODE_ENV === "development") {
|
||||
console.log("🚀 Starting Llama Stack UI Server...");
|
||||
|
||||
// Check optional configurations
|
||||
const optionalConfigs = {
|
||||
NEXTAUTH_URL: process.env.NEXTAUTH_URL || "http://localhost:8322",
|
||||
LLAMA_STACK_BACKEND_URL:
|
||||
process.env.LLAMA_STACK_BACKEND_URL || "http://localhost:8321",
|
||||
LLAMA_STACK_UI_PORT: process.env.LLAMA_STACK_UI_PORT || "8322",
|
||||
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
|
||||
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
|
||||
};
|
||||
|
||||
console.log("\n📋 Configuration:");
|
||||
console.log(` - NextAuth URL: ${optionalConfigs.NEXTAUTH_URL}`);
|
||||
console.log(` - Backend URL: ${optionalConfigs.LLAMA_STACK_BACKEND_URL}`);
|
||||
console.log(` - UI Port: ${optionalConfigs.LLAMA_STACK_UI_PORT}`);
|
||||
|
||||
// Check GitHub OAuth configuration
|
||||
if (
|
||||
!optionalConfigs.GITHUB_CLIENT_ID ||
|
||||
!optionalConfigs.GITHUB_CLIENT_SECRET
|
||||
) {
|
||||
console.log(
|
||||
"\n📝 GitHub OAuth not configured (authentication features disabled)",
|
||||
);
|
||||
console.log(" To enable GitHub OAuth:");
|
||||
console.log(" 1. Go to https://github.com/settings/applications/new");
|
||||
console.log(
|
||||
" 2. Set Application name: Llama Stack UI (or your preferred name)",
|
||||
);
|
||||
console.log(" 3. Set Homepage URL: http://localhost:8322");
|
||||
console.log(
|
||||
" 4. Set Authorization callback URL: http://localhost:8322/api/auth/callback/github",
|
||||
);
|
||||
console.log(
|
||||
" 5. Create the app and copy the Client ID and Client Secret",
|
||||
);
|
||||
console.log(" 6. Add them to your .env.local file:");
|
||||
console.log(" GITHUB_CLIENT_ID=your_client_id");
|
||||
console.log(" GITHUB_CLIENT_SECRET=your_client_secret");
|
||||
} else {
|
||||
console.log(" - GitHub OAuth: ✅ Configured");
|
||||
}
|
||||
|
||||
console.log("");
|
||||
}
|
||||
}
|
||||
|
||||
// Call this function when the module is imported
|
||||
validateServerConfig();
|
147
llama_stack/ui/package-lock.json
generated
147
llama_stack/ui/package-lock.json
generated
|
@ -18,6 +18,7 @@
|
|||
"llama-stack-client": "0.2.13",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
"next-themes": "^0.4.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
|
@ -548,7 +549,6 @@
|
|||
"version": "7.27.1",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz",
|
||||
"integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
|
@ -2423,6 +2423,15 @@
|
|||
"node": ">=12.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@panva/hkdf": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/@panva/hkdf/-/hkdf-1.2.1.tgz",
|
||||
"integrity": "sha512-6oclG6Y3PiDFcoyk8srjLfVKyMfVCKJ27JwNPViuXziFpmdz+MZnZN/aKY0JGXgYuO/VghU0jcOAZgWXZ1Dmrw==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/panva"
|
||||
}
|
||||
},
|
||||
"node_modules/@pkgr/core": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.4.tgz",
|
||||
|
@ -5279,7 +5288,6 @@
|
|||
"version": "0.7.2",
|
||||
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz",
|
||||
"integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.6"
|
||||
|
@ -9036,6 +9044,15 @@
|
|||
"jiti": "lib/jiti-cli.mjs"
|
||||
}
|
||||
},
|
||||
"node_modules/jose": {
|
||||
"version": "4.15.9",
|
||||
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz",
|
||||
"integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/panva"
|
||||
}
|
||||
},
|
||||
"node_modules/js-tokens": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
|
||||
|
@ -9949,6 +9966,38 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/next-auth": {
|
||||
"version": "4.24.11",
|
||||
"resolved": "https://registry.npmjs.org/next-auth/-/next-auth-4.24.11.tgz",
|
||||
"integrity": "sha512-pCFXzIDQX7xmHFs4KVH4luCjaCbuPRtZ9oBUjUhOk84mZ9WVPf94n87TxYI4rSRf9HmfHEF8Yep3JrYDVOo3Cw==",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.20.13",
|
||||
"@panva/hkdf": "^1.0.2",
|
||||
"cookie": "^0.7.0",
|
||||
"jose": "^4.15.5",
|
||||
"oauth": "^0.9.15",
|
||||
"openid-client": "^5.4.0",
|
||||
"preact": "^10.6.3",
|
||||
"preact-render-to-string": "^5.1.19",
|
||||
"uuid": "^8.3.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@auth/core": "0.34.2",
|
||||
"next": "^12.2.5 || ^13 || ^14 || ^15",
|
||||
"nodemailer": "^6.6.5",
|
||||
"react": "^17.0.2 || ^18 || ^19",
|
||||
"react-dom": "^17.0.2 || ^18 || ^19"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@auth/core": {
|
||||
"optional": true
|
||||
},
|
||||
"nodemailer": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/next-themes": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.4.6.tgz",
|
||||
|
@ -10071,6 +10120,12 @@
|
|||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/oauth": {
|
||||
"version": "0.9.15",
|
||||
"resolved": "https://registry.npmjs.org/oauth/-/oauth-0.9.15.tgz",
|
||||
"integrity": "sha512-a5ERWK1kh38ExDEfoO6qUHJb32rd7aYmPHuyCu3Fta/cnICvYmgd2uhuKXvPD+PXB+gCEYYEaQdIRAjCOwAKNA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/object-assign": {
|
||||
"version": "4.1.1",
|
||||
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
|
||||
|
@ -10081,6 +10136,15 @@
|
|||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/object-hash": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz",
|
||||
"integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/object-inspect": {
|
||||
"version": "1.13.4",
|
||||
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
|
||||
|
@ -10194,6 +10258,15 @@
|
|||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/oidc-token-hash": {
|
||||
"version": "5.1.0",
|
||||
"resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.1.0.tgz",
|
||||
"integrity": "sha512-y0W+X7Ppo7oZX6eovsRkuzcSM40Bicg2JEJkDJ4irIt1wsYAP5MLSNv+QAogO8xivMffw/9OvV3um1pxXgt1uA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": "^10.13.0 || >=12.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/on-finished": {
|
||||
"version": "2.4.1",
|
||||
"resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz",
|
||||
|
@ -10233,6 +10306,39 @@
|
|||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/openid-client": {
|
||||
"version": "5.7.1",
|
||||
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.7.1.tgz",
|
||||
"integrity": "sha512-jDBPgSVfTnkIh71Hg9pRvtJc6wTwqjRkN88+gCFtYWrlP4Yx2Dsrow8uPi3qLr/aeymPF3o2+dS+wOpglK04ew==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"jose": "^4.15.9",
|
||||
"lru-cache": "^6.0.0",
|
||||
"object-hash": "^2.2.0",
|
||||
"oidc-token-hash": "^5.0.3"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/panva"
|
||||
}
|
||||
},
|
||||
"node_modules/openid-client/node_modules/lru-cache": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz",
|
||||
"integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"yallist": "^4.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/openid-client/node_modules/yallist": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
|
||||
"integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==",
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/optionator": {
|
||||
"version": "0.9.4",
|
||||
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
|
||||
|
@ -10560,6 +10666,34 @@
|
|||
"node": "^10 || ^12 || >=14"
|
||||
}
|
||||
},
|
||||
"node_modules/preact": {
|
||||
"version": "10.26.9",
|
||||
"resolved": "https://registry.npmjs.org/preact/-/preact-10.26.9.tgz",
|
||||
"integrity": "sha512-SSjF9vcnF27mJK1XyFMNJzFd5u3pQiATFqoaDy03XuN00u4ziveVVEGt5RKJrDR8MHE/wJo9Nnad56RLzS2RMA==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/preact"
|
||||
}
|
||||
},
|
||||
"node_modules/preact-render-to-string": {
|
||||
"version": "5.2.6",
|
||||
"resolved": "https://registry.npmjs.org/preact-render-to-string/-/preact-render-to-string-5.2.6.tgz",
|
||||
"integrity": "sha512-JyhErpYOvBV1hEPwIxc/fHWXPfnEGdRKxc8gFdAZ7XV4tlzyzG847XAyEZqoDnynP88akM4eaHcSOzNcLWFguw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"pretty-format": "^3.8.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"preact": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/preact-render-to-string/node_modules/pretty-format": {
|
||||
"version": "3.8.0",
|
||||
"resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-3.8.0.tgz",
|
||||
"integrity": "sha512-WuxUnVtlWL1OfZFQFuqvnvs6MiAGk9UNsBostyBOB0Is9wb5uRESevA6rnl/rkksXaGX3GzZhPup5d6Vp1nFew==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/prelude-ls": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
|
||||
|
@ -12409,6 +12543,15 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "8.3.2",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz",
|
||||
"integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==",
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
"llama-stack-client": "0.2.13",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
"next-themes": "^0.4.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
|
|
7
llama_stack/ui/types/next-auth.d.ts
vendored
Normal file
7
llama_stack/ui/types/next-auth.d.ts
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
import NextAuth from "next-auth";
|
||||
|
||||
declare module "next-auth" {
|
||||
interface Session {
|
||||
accessToken?: string;
|
||||
}
|
||||
}
|
|
@ -86,6 +86,7 @@ unit = [
|
|||
"sqlalchemy[asyncio]>=2.0.41",
|
||||
"blobfile",
|
||||
"faiss-cpu",
|
||||
"pymilvus>=2.5.12",
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
|
@ -106,6 +107,7 @@ test = [
|
|||
"sqlalchemy",
|
||||
"sqlalchemy[asyncio]>=2.0.41",
|
||||
"requests",
|
||||
"pymilvus>=2.5.12",
|
||||
]
|
||||
docs = [
|
||||
"setuptools",
|
||||
|
|
5
tests/integration/providers/utils/__init__.py
Normal file
5
tests/integration/providers/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,173 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.access_control.access_control import default_policy
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
|
||||
|
||||
|
||||
def get_postgres_config():
|
||||
"""Get PostgreSQL configuration if tests are enabled."""
|
||||
return PostgresSqlStoreConfig(
|
||||
host=os.environ.get("POSTGRES_HOST", "localhost"),
|
||||
port=int(os.environ.get("POSTGRES_PORT", "5432")),
|
||||
db=os.environ.get("POSTGRES_DB", "llamastack"),
|
||||
user=os.environ.get("POSTGRES_USER", "llamastack"),
|
||||
password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
|
||||
)
|
||||
|
||||
|
||||
def get_sqlite_config():
|
||||
"""Get SQLite configuration with temporary database."""
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
tmp_file.close()
|
||||
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"backend_config",
|
||||
[
|
||||
pytest.param(
|
||||
("postgres", get_postgres_config),
|
||||
marks=pytest.mark.skipif(
|
||||
not os.environ.get("ENABLE_POSTGRES_TESTS"),
|
||||
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
|
||||
),
|
||||
id="postgres",
|
||||
),
|
||||
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
|
||||
],
|
||||
)
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
||||
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
|
||||
backend_name, config_func = backend_config
|
||||
|
||||
# Handle different config types
|
||||
if backend_name == "postgres":
|
||||
config = config_func()
|
||||
cleanup_path = None
|
||||
else: # sqlite
|
||||
config, cleanup_path = config_func()
|
||||
|
||||
try:
|
||||
base_sqlstore = SqlAlchemySqlStoreImpl(config)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
|
||||
# Create test table
|
||||
table_name = f"test_json_comparison_{backend_name}"
|
||||
await authorized_store.create_table(
|
||||
table=table_name,
|
||||
schema={
|
||||
"id": ColumnType.STRING,
|
||||
"data": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Test with no authenticated user (should handle JSON null comparison)
|
||||
mock_get_authenticated_user.return_value = None
|
||||
|
||||
# Insert some test data
|
||||
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||
|
||||
# Test fetching with no user - should not error on JSON comparison
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["id"] == "1"
|
||||
assert result.data[0]["access_attributes"] is None
|
||||
|
||||
# Test with authenticated user
|
||||
test_user = User("test-user", {"roles": ["admin"]})
|
||||
mock_get_authenticated_user.return_value = test_user
|
||||
|
||||
# Insert data with user attributes
|
||||
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||
|
||||
# Fetch all - admin should see both
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
assert len(result.data) == 2
|
||||
|
||||
# Test with non-admin user
|
||||
regular_user = User("regular-user", {"roles": ["user"]})
|
||||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
# Should only see public record
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["id"] == "1"
|
||||
|
||||
# Test the category missing branch: user with multiple attributes
|
||||
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
|
||||
mock_get_authenticated_user.return_value = multi_user
|
||||
|
||||
# Insert record with multi-user (has both roles and teams)
|
||||
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
|
||||
|
||||
# Test different user types to create records with different attribute patterns
|
||||
# Record with only roles (teams category will be missing)
|
||||
roles_only_user = User("roles-user", {"roles": ["admin"]})
|
||||
mock_get_authenticated_user.return_value = roles_only_user
|
||||
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
|
||||
|
||||
# Record with only teams (roles category will be missing)
|
||||
teams_only_user = User("teams-user", {"teams": ["dev"]})
|
||||
mock_get_authenticated_user.return_value = teams_only_user
|
||||
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
|
||||
|
||||
# Record with different roles/teams (shouldn't match our test user)
|
||||
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
|
||||
mock_get_authenticated_user.return_value = different_user
|
||||
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
|
||||
|
||||
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||
mock_get_authenticated_user.return_value = multi_user
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
|
||||
# Should see:
|
||||
# - public record (1) - no access_attributes
|
||||
# - admin record (2) - user matches roles=admin, teams missing (allowed)
|
||||
# - multi_user record (3) - user matches both roles=admin and teams=dev
|
||||
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
|
||||
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
|
||||
# Should NOT see:
|
||||
# - different_user record (6) - user doesn't match roles=user or teams=qa
|
||||
expected_ids = {"1", "2", "3", "4", "5"}
|
||||
actual_ids = {record["id"] for record in result.data}
|
||||
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
|
||||
|
||||
# Verify the category missing logic specifically
|
||||
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
|
||||
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
|
||||
assert category_test_ids == {"4", "5"}, (
|
||||
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up records
|
||||
for record_id in ["1", "2", "3", "4", "5", "6"]:
|
||||
try:
|
||||
await base_sqlstore.delete(table_name, {"id": record_id})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up temporary SQLite database file if needed
|
||||
if cleanup_path:
|
||||
try:
|
||||
os.unlink(cleanup_path)
|
||||
except OSError:
|
||||
pass
|
|
@ -0,0 +1,420 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from pymilvus import Collection, MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
# TODO: Refactor these to be for inline vector-io providers
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_inference_api(embedding_dimension):
|
||||
class MockInferenceAPI:
|
||||
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
|
||||
|
||||
return MockInferenceAPI()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def unique_kvstore_config(tmp_path_factory):
|
||||
# Generate a unique filename for this test
|
||||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_milvus.db")
|
||||
client = MilvusClient(db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = db_path
|
||||
yield index
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
|
||||
config = MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_index.db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = MilvusVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=adapter.metadata_collection_name,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_contains_initial_collection(milvus_vec_adapter):
|
||||
coll_name = milvus_vec_adapter.metadata_collection_name
|
||||
assert coll_name in milvus_vec_adapter.cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
|
||||
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
||||
assert resp.chunks[0].content == sample_chunks[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
|
||||
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_emb = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
|
||||
assert isinstance(resp, QueryChunksResponse)
|
||||
assert len(resp.chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
|
||||
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
||||
await milvus_vec_index.add_chunks(sample_chunks, embeddings)
|
||||
coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
|
||||
ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
|
||||
flat_ids = [i["id"] for i in ids]
|
||||
assert len(flat_ids) == len(set(flat_ids))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
|
||||
kvstore = await kvstore_impl(unique_kvstore_config)
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
metadata={"test_key": "test_value"},
|
||||
)
|
||||
test_vector_db_data = vector_db.model_dump_json()
|
||||
await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
|
||||
tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
|
||||
config=MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_index.db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
),
|
||||
inference_api=None,
|
||||
files_api=None,
|
||||
)
|
||||
await tmp_milvus_vec_adapter.initialize()
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test_db",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
test_vector_db_data = vector_db.model_dump_json()
|
||||
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
|
||||
|
||||
assert milvus_vec_index.client is not None
|
||||
assert isinstance(milvus_vec_index.client, MilvusClient)
|
||||
assert tmp_milvus_vec_adapter.cache is not None
|
||||
# registering a vector won't update the cache or openai_vector_store collection name
|
||||
assert (
|
||||
tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
|
||||
or tmp_milvus_vec_adapter.openai_vector_stores
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_across_adapter_restarts(
|
||||
tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
|
||||
):
|
||||
adapter1 = MilvusVectorIOAdapter(
|
||||
config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter1.initialize()
|
||||
dummy = VectorDB(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await adapter1.register_vector_db(dummy)
|
||||
await adapter1.shutdown()
|
||||
|
||||
await adapter1.initialize()
|
||||
assert "foo_db" in adapter1.cache
|
||||
await adapter1.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_and_unregister_vector_db(milvus_vec_adapter):
|
||||
try:
|
||||
connections.disconnect(MILVUS_ALIAS)
|
||||
except Exception as _:
|
||||
pass
|
||||
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
|
||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||
dummy = VectorDB(
|
||||
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
|
||||
await milvus_vec_adapter.register_vector_db(dummy)
|
||||
assert dummy.identifier in milvus_vec_adapter.cache
|
||||
|
||||
if dummy.identifier in milvus_vec_adapter.cache:
|
||||
index = milvus_vec_adapter.cache[dummy.identifier].index
|
||||
if hasattr(index, "client") and hasattr(index.client, "_using"):
|
||||
index.client._using = MILVUS_ALIAS
|
||||
|
||||
await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
|
||||
assert dummy.identifier not in milvus_vec_adapter.cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_unregistered_raises(milvus_vec_adapter):
|
||||
fake_emb = np.zeros(8, dtype=np.float32)
|
||||
with pytest.raises(AttributeError):
|
||||
await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
|
||||
fake_index = AsyncMock()
|
||||
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||
|
||||
chunks = ["chunk1", "chunk2"]
|
||||
await milvus_vec_adapter.insert_chunks("db1", chunks)
|
||||
|
||||
fake_index.insert_chunks.assert_awaited_once_with(chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
|
||||
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await milvus_vec_adapter.insert_chunks("db_not_exist", [])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||
|
||||
response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})
|
||||
|
||||
fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
|
||||
assert response is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
|
||||
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await milvus_vec_adapter.query_chunks("db_missing", "q", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_openai_vector_store(milvus_vec_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||
|
||||
assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
|
||||
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_openai_vector_store(milvus_vec_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||
openai_vector_store["description"] = "Updated description"
|
||||
await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
|
||||
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_openai_vector_store(milvus_vec_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||
await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
|
||||
assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_openai_vector_stores(milvus_vec_adapter):
|
||||
store_id = "vs_1234"
|
||||
openai_vector_store = {
|
||||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
|
||||
loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
|
||||
assert loaded_stores[store_id] == openai_vector_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
||||
file_info = {
|
||||
"id": file_id,
|
||||
"status": "completed",
|
||||
"vector_store_id": store_id,
|
||||
"attributes": {},
|
||||
"filename": "test_file.txt",
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
|
||||
file_contents = [
|
||||
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||
]
|
||||
|
||||
# validating we don't raise an exception
|
||||
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
||||
file_info = {
|
||||
"id": file_id,
|
||||
"status": "completed",
|
||||
"vector_store_id": store_id,
|
||||
"attributes": {},
|
||||
"filename": "test_file.txt",
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
|
||||
file_contents = [
|
||||
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||
]
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
|
||||
updated_file_info = file_info.copy()
|
||||
updated_file_info["filename"] = "updated_test_file.txt"
|
||||
|
||||
await milvus_vec_adapter._update_openai_vector_store_file(
|
||||
store_id,
|
||||
file_id,
|
||||
updated_file_info,
|
||||
)
|
||||
|
||||
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
|
||||
assert loaded_contents == updated_file_info
|
||||
assert loaded_contents != file_info
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
||||
file_info = {
|
||||
"id": file_id,
|
||||
"status": "completed",
|
||||
"vector_store_id": store_id,
|
||||
"attributes": {},
|
||||
"filename": "test_file.txt",
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
|
||||
file_contents = [
|
||||
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||
]
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
|
||||
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||
assert loaded_contents == file_contents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
|
||||
store_id = "vs_1234"
|
||||
file_id = "file_1234"
|
||||
|
||||
file_info = {
|
||||
"id": file_id,
|
||||
"status": "completed",
|
||||
"vector_store_id": store_id,
|
||||
"attributes": {},
|
||||
"filename": "test_file.txt",
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
|
||||
file_contents = [
|
||||
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
|
||||
]
|
||||
|
||||
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||
await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
|
||||
|
||||
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
|
||||
assert loaded_contents == []
|
|
@ -11,10 +11,16 @@ import pytest
|
|||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationConfig,
|
||||
AuthProviderType,
|
||||
CustomAuthConfig,
|
||||
OAuth2IntrospectionConfig,
|
||||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
AuthProviderType,
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
||||
|
@ -61,24 +67,11 @@ def invalid_token():
|
|||
def http_app(mock_auth_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -94,11 +87,6 @@ def http_client(http_app):
|
|||
return TestClient(http_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def k8s_client(k8s_app):
|
||||
return TestClient(k8s_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scope():
|
||||
return {
|
||||
|
@ -117,18 +105,11 @@ def mock_scope():
|
|||
def mock_http_middleware(mock_auth_endpoint):
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.CUSTOM,
|
||||
config={"endpoint": mock_auth_endpoint},
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_k8s_middleware():
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.KUBERNETES,
|
||||
config={"api_server_url": "https://kubernetes.default.svc"},
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
|
||||
|
@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
|
|||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "validated by mock-auth-service" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format(http_client):
|
||||
response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success)
|
||||
|
@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
|
|||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
|
|||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
|
@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
|
|||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "my-jwks-token",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
jwks=OAuth2JWKSConfig(
|
||||
uri="http://mock-authz-service/token/introspect",
|
||||
key_recheck_period=3600,
|
||||
token="my-jwks-token",
|
||||
),
|
||||
audience="llama-stack",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -449,11 +433,15 @@ def mock_introspection_endpoint():
|
|||
def introspection_app(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
|
||||
},
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
),
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
|
|||
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": None,
|
||||
"introspection": {
|
||||
"url": mock_introspection_endpoint,
|
||||
"client_id": "myclient",
|
||||
"client_secret": "abcdefg",
|
||||
"send_secret_in_body": "true",
|
||||
},
|
||||
"claims_mapping": {
|
||||
provider_config=OAuth2TokenAuthConfig(
|
||||
type=AuthProviderType.OAUTH2_TOKEN,
|
||||
introspection=OAuth2IntrospectionConfig(
|
||||
url=mock_introspection_endpoint,
|
||||
client_id="myclient",
|
||||
client_secret="abcdefg",
|
||||
send_secret_in_body=True,
|
||||
),
|
||||
claims_mapping={
|
||||
"sub": "roles",
|
||||
"scope": "roles",
|
||||
"groups": "teams",
|
||||
"aud": "namespaces",
|
||||
},
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
|
@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
|
|||
def test_missing_auth_header_introspection(introspection_client):
|
||||
response = introspection_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "OAuth2 Bearer token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_introspection(introspection_client):
|
||||
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_introspection_active(*args, **kwargs):
|
||||
|
|
200
tests/unit/server/test_auth_github.py
Normal file
200
tests/unit/server/test_auth_github.py
Normal file
|
@ -0,0 +1,200 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status_code, json_data):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
# Create a mock request for the HTTPStatusError
|
||||
mock_request = httpx.Request("GET", "https://api.github.com/user")
|
||||
raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_app():
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://api.github.com",
|
||||
claims_mapping={
|
||||
"login": "username",
|
||||
"id": "user_id",
|
||||
"organizations": "teams",
|
||||
},
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
# Add auth middleware
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_token_client(github_token_app):
|
||||
return TestClient(github_token_app)
|
||||
|
||||
|
||||
def test_authenticated_endpoint_without_token(github_token_client):
|
||||
"""Test accessing protected endpoint without token"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
assert "GitHub access token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
|
||||
"""Test accessing protected endpoint with invalid bearer format"""
|
||||
response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with valid GitHub token"""
|
||||
# Mock the GitHub API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "testuser",
|
||||
"id": 12345,
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "test-org-1"},
|
||||
{"login": "test-org-2"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Authentication successful"
|
||||
|
||||
# Verify the GitHub API was called correctly
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://api.github.com/user"
|
||||
|
||||
# Check authorization header was passed
|
||||
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
|
||||
"""Test accessing protected endpoint with invalid GitHub token"""
|
||||
# Mock the GitHub API to return 401 Unauthorized
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock failed user API response
|
||||
mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})
|
||||
|
||||
response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
|
||||
assert response.status_code == 401
|
||||
assert (
|
||||
"GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"]
|
||||
)
|
||||
|
||||
|
||||
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
|
||||
def test_github_enterprise_support(mock_client_class):
|
||||
"""Test GitHub Enterprise support with custom API base URL"""
|
||||
app = FastAPI()
|
||||
|
||||
# Configure GitHub token auth with enterprise URL
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=GitHubTokenAuthConfig(
|
||||
type=AuthProviderType.GITHUB_TOKEN,
|
||||
github_api_base_url="https://github.enterprise.com/api/v3",
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock the GitHub Enterprise API responses
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Mock successful user API response
|
||||
mock_client.get.side_effect = [
|
||||
MockResponse(
|
||||
200,
|
||||
{
|
||||
"login": "enterprise_user",
|
||||
"id": 99999,
|
||||
"email": "user@enterprise.com",
|
||||
},
|
||||
),
|
||||
MockResponse(
|
||||
200,
|
||||
[
|
||||
{"login": "enterprise-org"},
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the correct GitHub Enterprise URLs were called
|
||||
assert mock_client.get.call_count == 1
|
||||
calls = mock_client.get.call_args_list
|
||||
assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"
|
||||
|
||||
|
||||
def test_github_token_auth_error_message_format(github_token_client):
|
||||
"""Test that the error message for missing auth is properly formatted"""
|
||||
response = github_token_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
error_data = response.json()
|
||||
assert "error" in error_data
|
||||
assert "message" in error_data["error"]
|
||||
assert "Authentication required" in error_data["error"]["message"]
|
||||
assert "https://docs.github.com" in error_data["error"]["message"] # Contains link to GitHub docs
|
|
@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
|
||||
# Test scenarios with different access control patterns
|
||||
test_scenarios = [
|
||||
# Scenario 1: Public record (no access control)
|
||||
# Scenario 1: Public record (no access control - represents None user insert)
|
||||
{"id": "1", "name": "public", "access_attributes": None},
|
||||
# Scenario 2: Empty access control (should be treated as public)
|
||||
{"id": "2", "name": "empty", "access_attributes": {}},
|
||||
# Scenario 3: Record with roles requirement
|
||||
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||
# Scenario 4: Record with multiple attribute categories
|
||||
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||
# Scenario 5: Record with teams only (missing roles category)
|
||||
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||
# Scenario 6: Record with roles and projects
|
||||
# Scenario 2: Record with roles requirement
|
||||
{"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||
# Scenario 3: Record with multiple attribute categories
|
||||
{"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||
# Scenario 4: Record with teams only (missing roles category)
|
||||
{"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||
# Scenario 5: Record with roles and projects
|
||||
{
|
||||
"id": "6",
|
||||
"id": "5",
|
||||
"name": "admin-project-x",
|
||||
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue