Merge branch 'main' into vectordb_name

Francisco Arceo 2025-07-08 16:45:50 -04:00 committed by GitHub
commit b103ee9eb8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 4277 additions and 1773 deletions

View file

@@ -73,9 +73,12 @@ jobs:
           server:
             port: 8321
         EOF
-        yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
-        yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
-        yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.type = "${{ matrix.auth-provider }}"' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.tls_cafile = "${{ env.KUBERNETES_CA_CERT_PATH }}"' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.issuer = "${{ env.KUBERNETES_ISSUER }}"' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.audience = "${{ env.KUBERNETES_AUDIENCE }}"' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.jwks.uri = "${{ env.KUBERNETES_API_SERVER_URL }}"' -i $run_dir/run.yaml
+        yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
        cat $run_dir/run.yaml
        nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
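For reference, the `server.auth` block these `yq` edits build up should look roughly like the following (a Python/PyYAML sketch; values in angle brackets are placeholders for the workflow's matrix and env settings):

```python
# A sketch of the run.yaml fragment produced by the yq edits above; the
# angle-bracket values are placeholders, not real settings.
import yaml  # PyYAML

auth_fragment = {
    "server": {
        "port": 8321,
        "auth": {
            "provider_config": {
                "type": "oauth2_token",  # matrix.auth-provider
                "tls_cafile": "<KUBERNETES_CA_CERT_PATH>",
                "issuer": "<KUBERNETES_ISSUER>",
                "audience": "<KUBERNETES_AUDIENCE>",
                "jwks": {
                    "uri": "<KUBERNETES_API_SERVER_URL>",
                    "token": "<TOKEN>",
                },
            },
        },
    },
}
print(yaml.safe_dump(auth_fragment, sort_keys=False))
```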

View file

@@ -0,0 +1,70 @@
name: SqlStore Integration Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'llama_stack/providers/utils/sqlstore/**'
      - 'tests/integration/sqlstore/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/integration-sql-store-tests.yml' # This workflow

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-postgres:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12", "3.13"]
      fail-fast: false

    services:
      postgres:
        image: postgres:15
        env:
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
          POSTGRES_DB: llamastack
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run SqlStore Integration Tests
        env:
          ENABLE_POSTGRES_TESTS: "true"
          POSTGRES_HOST: localhost
          POSTGRES_PORT: 5432
          POSTGRES_DB: llamastack
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
        run: |
          uv run pytest -sv tests/integration/providers/utils/sqlstore/

      - name: Upload test logs
        if: ${{ always() }}
        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
        with:
          name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
          path: |
            *.log
          retention-days: 1

View file

@@ -56,8 +56,8 @@ shields: []
 server:
   port: 8321
   auth:
-    provider_type: "oauth2_token"
-    config:
+    provider_config:
+      type: "oauth2_token"
       jwks:
         uri: "https://my-token-issuing-svc.com/jwks"
 ```
@@ -226,6 +226,8 @@ server:
 ### Authentication Configuration

+> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
+
 The `auth` section configures authentication for the server. When configured, all API requests must include a valid Bearer token in the Authorization header:

 ```
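To make the migration concrete, here is a sketch (assuming the `AuthenticationConfig` model added in this commit is importable from `llama_stack.distribution.datatypes`) of validating a config in the new shape:

```python
# A sketch of the new (v0.2.14+) auth config shape; the old shape
# {"provider_type": "oauth2_token", "config": {...}} is no longer accepted.
from llama_stack.distribution.datatypes import AuthenticationConfig

config = AuthenticationConfig.model_validate(
    {
        "provider_config": {
            "type": "oauth2_token",  # discriminator field, formerly provider_type
            "jwks": {"uri": "https://my-token-issuing-svc.com/jwks"},
        }
    }
)
print(type(config.provider_config).__name__)  # OAuth2TokenAuthConfig
```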
@@ -240,8 +242,8 @@ The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server:
 ```yaml
 server:
   auth:
-    provider_type: "oauth2_token"
-    config:
+    provider_config:
+      type: "oauth2_token"
       jwks:
         uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
         token: "${env.TOKEN:+}"
@@ -325,13 +327,25 @@ You can easily validate a request by running:
 curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
 ```

+#### GitHub Token Provider
+
+Validates GitHub personal access tokens or OAuth tokens directly:
+
+```yaml
+server:
+  auth:
+    provider_config:
+      type: "github_token"
+      github_api_base_url: "https://api.github.com"  # Or GitHub Enterprise URL
+```
+
+The provider fetches user information from GitHub and maps it to access attributes based on the `claims_mapping` configuration.
+
 #### Custom Provider

 Validates tokens against a custom authentication endpoint:

 ```yaml
 server:
   auth:
-    provider_type: "custom"
-    config:
+    provider_config:
+      type: "custom"
       endpoint: "https://auth.example.com/validate"  # URL of the auth endpoint
 ```
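For illustration, a client-side sketch of calling a stack configured with the `github_token` provider (the token value below is hypothetical; any valid GitHub personal access token works):

```python
# A sketch: send a GitHub personal access token as the Bearer token.
import httpx

resp = httpx.get(
    "http://127.0.0.1:8321/v1/providers",
    headers={"Authorization": "Bearer ghp_your_token_here"},  # hypothetical PAT
)
resp.raise_for_status()
print(resp.json())
```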
@@ -416,8 +430,8 @@ clients.
 server:
   port: 8321
   auth:
-    provider_type: custom
-    config:
+    provider_config:
+      type: custom
       endpoint: https://auth.example.com/validate
   quota:
     kvstore:

View file

@@ -39,6 +39,13 @@ docker pull llama-stack/distribution-meta-reference-gpu

 **Guides:** [Meta Reference GPU Guide](self_hosted_distro/meta-reference-gpu)

+### 🖥️ Self-Hosted with NVIDIA NeMo Microservices
+
+**Use `nvidia` if you:**
+- Want to use Llama Stack with NVIDIA NeMo Microservices
+
+**Guides:** [NVIDIA Distribution Guide](self_hosted_distro/nvidia)
+
 ### ☁️ Managed Hosting

 **Use remote-hosted endpoints if you:**

View file

@@ -0,0 +1,177 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs`, `remote::nvidia` |
| eval | `remote::nvidia` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `NVIDIA_GUARDRAILS_CONFIG_ID`: NVIDIA Guardrail Configuration ID (default: `self-check`)
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
The following models are available by default:
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2`
- `nvidia/nv-embedqa-e5-v5`
- `nvidia/nv-embedqa-mistral-7b-v2`
- `snowflake/arctic-embed-l`
## Prerequisites
### NVIDIA API Keys
Make sure you have access to an NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
### Deploy NeMo Microservices Platform
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
## Supported Services
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
### Inference: NVIDIA NIM
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
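Since Data Store is `HfApi`-compatible, a minimal sketch of pointing the Hugging Face client at it (assuming `huggingface_hub` is installed and `NVIDIA_DATASETS_URL` is set) might look like:

```python
# A sketch, not an official example: use the Hugging Face Hub client against
# the NeMo Data Store's HfApi-compatible endpoint.
import os

from huggingface_hub import HfApi

api = HfApi(endpoint=os.environ["NVIDIA_DATASETS_URL"])
for ds in api.list_datasets():  # list datasets stored in the Data Store
    print(ds.id)
```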
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
Note: For improved inference speeds, use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct",
    "namespace": "meta",
    "config": {
      "model": "meta/llama-3.2-1b-instruct",
      "nim_deployment": {
        "image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
        "image_tag": "1.8.3",
        "pvc_size": "25Gi",
        "gpu": 1,
        "additional_envs": {
          "NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
        }
      }
    }
  }'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```
## Running Llama Stack with NVIDIA
You can do this via Conda or venv (building the code), or via Docker, which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-nvidia \
  --config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
### Via Conda
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
## Example Notebooks
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.

View file

@@ -12,6 +12,7 @@ Please refer to the remote provider documentation.
 |-------|------|----------|---------|-------------|
 | `db_path` | `<class 'str'>` | No | PydanticUndefined | |
 | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
+| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |

 ## Sample Configuration

View file

@@ -403,15 +403,16 @@ def _run_stack_build_command_from_build_config(
     if template_name:
         # copy run.yaml from template to build_dir instead of generating it again
         template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
-        with importlib.resources.as_file(template_path) as path:
-            run_config_file = build_dir / f"{template_name}-run.yaml"
+        run_config_file = build_dir / f"{template_name}-run.yaml"
+        with importlib.resources.as_file(template_path) as path:
             shutil.copy(path, run_config_file)
+
         cprint("Build Successful!", color="green", file=sys.stderr)
-        cprint(f"You can find the newly-built template here: {template_path}", color="blue", file=sys.stderr)
+        cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr)
         cprint(
             "You can run the new Llama Stack distro via: "
-            + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "blue"),
+            + colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
             color="green",
             file=sys.stderr,
         )

View file

@@ -155,7 +155,11 @@ class StackRun(Subcommand):
             # func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
             if callable(getattr(args, arg)):
                 continue
-            if arg == "config" and template_name:
+            if arg == "config":
+                if template_name:
+                    server_args.template = str(template_name)
+                else:
+                    # Set the config file path
                     server_args.config = str(config_file)
             else:
                 setattr(server_args, arg, getattr(args, arg))
@@ -168,6 +172,9 @@ class StackRun(Subcommand):
             run_args.extend([str(args.port)])

             if config_file:
+                if template_name:
+                    run_args.extend(["--template", str(template_name)])
+                else:
                     run_args.extend(["--config", str(config_file)])

             if args.env:

View file

@@ -6,9 +6,9 @@
 from enum import StrEnum
 from pathlib import Path
-from typing import Annotated, Any
+from typing import Annotated, Any, Literal, Self

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator

 from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
 from llama_stack.apis.datasetio import DatasetIO
@@ -161,23 +161,113 @@ class LoggingConfig(BaseModel):
     )


+class OAuth2JWKSConfig(BaseModel):
+    # The JWKS URI for collecting public keys
+    uri: str
+    token: str | None = Field(default=None, description="token to authorise access to jwks")
+    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
+
+
+class OAuth2IntrospectionConfig(BaseModel):
+    url: str
+    client_id: str
+    client_secret: str
+    send_secret_in_body: bool = False
+
+
 class AuthProviderType(StrEnum):
     """Supported authentication provider types."""

     OAUTH2_TOKEN = "oauth2_token"
+    GITHUB_TOKEN = "github_token"
     CUSTOM = "custom"


+class OAuth2TokenAuthConfig(BaseModel):
+    """Configuration for OAuth2 token authentication."""
+
+    type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
+    audience: str = Field(default="llama-stack")
+    verify_tls: bool = Field(default=True)
+    tls_cafile: Path | None = Field(default=None)
+    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
+    claims_mapping: dict[str, str] = Field(
+        default_factory=lambda: {
+            "sub": "roles",
+            "username": "roles",
+            "groups": "teams",
+            "team": "teams",
+            "project": "projects",
+            "tenant": "namespaces",
+            "namespace": "namespaces",
+        },
+    )
+    jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
+    introspection: OAuth2IntrospectionConfig | None = Field(
+        default=None, description="OAuth2 introspection configuration"
+    )
+
+    @classmethod
+    @field_validator("claims_mapping")
+    def validate_claims_mapping(cls, v):
+        for key, value in v.items():
+            if not value:
+                raise ValueError(f"claims_mapping value cannot be empty: {key}")
+        return v
+
+    @model_validator(mode="after")
+    def validate_mode(self) -> Self:
+        if not self.jwks and not self.introspection:
+            raise ValueError("One of jwks or introspection must be configured")
+        if self.jwks and self.introspection:
+            raise ValueError("At present only one of jwks or introspection should be configured")
+        return self
+
+
+class CustomAuthConfig(BaseModel):
+    """Configuration for custom authentication."""
+
+    type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
+    endpoint: str = Field(
+        ...,
+        description="Custom authentication endpoint URL",
+    )
+
+
+class GitHubTokenAuthConfig(BaseModel):
+    """Configuration for GitHub token authentication."""
+
+    type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
+    github_api_base_url: str = Field(
+        default="https://api.github.com",
+        description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
+    )
+    claims_mapping: dict[str, str] = Field(
+        default_factory=lambda: {
+            "login": "roles",
+            "organizations": "teams",
+        },
+        description="Mapping from GitHub user fields to access attributes",
+    )
+
+
+AuthProviderConfig = Annotated[
+    OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
+    Field(discriminator="type"),
+]
+
+
 class AuthenticationConfig(BaseModel):
-    provider_type: AuthProviderType = Field(
+    """Top-level authentication configuration."""
+
+    provider_config: AuthProviderConfig = Field(
         ...,
-        description="Type of authentication provider",
+        description="Authentication provider configuration",
     )
-    config: dict[str, Any] = Field(
-        ...,
-        description="Provider-specific configuration",
-    )
-    access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
+    access_policy: list[AccessRule] = Field(
+        default=[],
+        description="Rules for determining access to resources",
+    )


 class AuthenticationRequiredError(Exception):
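A quick sketch of how the new discriminated union behaves under pydantic v2 (names as defined above):

```python
# A sketch: the "type" field selects the concrete config class, and the
# model validators enforce the jwks/introspection constraint.
from pydantic import TypeAdapter, ValidationError

adapter = TypeAdapter(AuthProviderConfig)

cfg = adapter.validate_python({"type": "github_token"})
assert isinstance(cfg, GitHubTokenAuthConfig)
assert cfg.github_api_base_url == "https://api.github.com"

try:
    adapter.validate_python({"type": "oauth2_token"})  # neither jwks nor introspection
except ValidationError as e:
    print(e)  # "One of jwks or introspection must be configured"
```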

View file

@@ -87,8 +87,12 @@ class AuthenticationMiddleware:
         headers = dict(scope.get("headers", []))
         auth_header = headers.get(b"authorization", b"").decode()

-        if not auth_header or not auth_header.startswith("Bearer "):
-            return await self._send_auth_error(send, "Missing or invalid Authorization header")
+        if not auth_header:
+            error_msg = self.auth_provider.get_auth_error_message(scope)
+            return await self._send_auth_error(send, error_msg)
+
+        if not auth_header.startswith("Bearer "):
+            return await self._send_auth_error(send, "Invalid Authorization header format")

         token = auth_header.split("Bearer ", 1)[1]
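A client-side sketch of the two rejection paths above (the 401 status is an assumption; the body text for the first case comes from the provider's `get_auth_error_message`):

```python
# A sketch illustrating the middleware's two rejection paths.
import httpx

base = "http://127.0.0.1:8321"

r = httpx.get(f"{base}/v1/providers")  # no Authorization header at all
print(r.status_code, r.text)  # provider-specific "Authentication required..." text

r = httpx.get(f"{base}/v1/providers", headers={"Authorization": "Basic abc"})
print(r.status_code, r.text)  # "Invalid Authorization header format"
```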

View file

@@ -8,15 +8,19 @@ import ssl
 import time
 from abc import ABC, abstractmethod
 from asyncio import Lock
-from pathlib import Path
-from typing import Self
-from urllib.parse import parse_qs
+from urllib.parse import parse_qs, urlparse

 import httpx
 from jose import jwt
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, Field

-from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
+from llama_stack.distribution.datatypes import (
+    AuthenticationConfig,
+    CustomAuthConfig,
+    GitHubTokenAuthConfig,
+    OAuth2TokenAuthConfig,
+    User,
+)
 from llama_stack.log import get_logger

 logger = get_logger(name=__name__, category="auth")
@@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
     headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

-    params: dict[str, list[str]] = Field(
-        description="Query parameters from the original request, parsed as dictionary of lists"
-    )
+    params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")


 class AuthRequest(BaseModel):
@@ -62,6 +64,10 @@ class AuthProvider(ABC):
         """Clean up any resources."""
         pass

+    def get_auth_error_message(self, scope: dict | None = None) -> str:
+        """Return provider-specific authentication error message."""
+        return "Authentication required"
+

 def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
     attributes: dict[str, list[str]] = {}
@@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]):
     return attributes


-class OAuth2JWKSConfig(BaseModel):
-    # The JWKS URI for collecting public keys
-    uri: str
-    token: str | None = Field(default=None, description="token to authorise access to jwks")
-    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
-
-
-class OAuth2IntrospectionConfig(BaseModel):
-    url: str
-    client_id: str
-    client_secret: str
-    send_secret_in_body: bool = False
-
-
-class OAuth2TokenAuthProviderConfig(BaseModel):
-    audience: str = "llama-stack"
-    verify_tls: bool = True
-    tls_cafile: Path | None = None
-    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
-    claims_mapping: dict[str, str] = Field(
-        default_factory=lambda: {
-            "sub": "roles",
-            "username": "roles",
-            "groups": "teams",
-            "team": "teams",
-            "project": "projects",
-            "tenant": "namespaces",
-            "namespace": "namespaces",
-        },
-    )
-    jwks: OAuth2JWKSConfig | None
-    introspection: OAuth2IntrospectionConfig | None = None
-
-    @classmethod
-    @field_validator("claims_mapping")
-    def validate_claims_mapping(cls, v):
-        for key, value in v.items():
-            if not value:
-                raise ValueError(f"claims_mapping value cannot be empty: {key}")
-        return v
-
-    @model_validator(mode="after")
-    def validate_mode(self) -> Self:
-        if not self.jwks and not self.introspection:
-            raise ValueError("One of jwks or introspection must be configured")
-        if self.jwks and self.introspection:
-            raise ValueError("At present only one of jwks or introspection should be configured")
-        return self
-
-
 class OAuth2TokenAuthProvider(AuthProvider):
     """
     JWT token authentication provider that validates a JWT token and extracts access attributes.
@@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
     This should be the standard authentication provider for most use cases.
     """

-    def __init__(self, config: OAuth2TokenAuthProviderConfig):
+    def __init__(self, config: OAuth2TokenAuthConfig):
         self.config = config
         self._jwks_at: float = 0.0
         self._jwks: dict[str, str] = {}
@@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
                 issuer=self.config.issuer,
             )
         except Exception as exc:
-            raise ValueError(f"Invalid JWT token: {token}") from exc
+            raise ValueError("Invalid JWT token") from exc

         # There are other standard claims, the most relevant of which is `scope`.
         # We should incorporate these into the access attributes.
@@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
     async def close(self):
         pass

+    def get_auth_error_message(self, scope: dict | None = None) -> str:
+        """Return OAuth2-specific authentication error message."""
+        if self.config.issuer:
+            return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
+        elif self.config.introspection:
+            # Extract domain from introspection URL for a cleaner message
+            domain = urlparse(self.config.introspection.url).netloc
+            return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
+        else:
+            return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
+
     async def _refresh_jwks(self) -> None:
         """
         Refresh the JWKS cache.
@@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
             self._jwks_at = time.time()


-class CustomAuthProviderConfig(BaseModel):
-    endpoint: str
-
-
 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""

-    def __init__(self, config: CustomAuthProviderConfig):
+    def __init__(self, config: CustomAuthConfig):
         self.config = config
         self._client = None
@@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
                 try:
                     response_data = response.json()
                     auth_response = AuthResponse(**response_data)
-                    return User(auth_response.principal, auth_response.attributes)
+                    return User(principal=auth_response.principal, attributes=auth_response.attributes)
                 except Exception as e:
                     logger.exception("Error parsing authentication response")
                     raise ValueError("Invalid authentication response format") from e
@@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider):
             await self._client.aclose()
             self._client = None

+    def get_auth_error_message(self, scope: dict | None = None) -> str:
+        """Return custom auth provider-specific authentication error message."""
+        domain = urlparse(self.config.endpoint).netloc
+        if domain:
+            return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
+        else:
+            return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
+
+
+class GitHubTokenAuthProvider(AuthProvider):
+    """
+    GitHub token authentication provider that validates GitHub access tokens directly.
+
+    This provider accepts GitHub personal access tokens or OAuth tokens and verifies
+    them against the GitHub API to get user information.
+    """
+
+    def __init__(self, config: GitHubTokenAuthConfig):
+        self.config = config
+
+    async def validate_token(self, token: str, scope: dict | None = None) -> User:
+        """Validate a GitHub token by calling the GitHub API.
+
+        This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
+        """
+        try:
+            user_info = await _get_github_user_info(token, self.config.github_api_base_url)
+        except httpx.HTTPStatusError as e:
+            logger.warning(f"GitHub token validation failed: {e}")
+            raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
+
+        principal = user_info["user"]["login"]
+
+        github_data = {
+            "login": user_info["user"]["login"],
+            "id": str(user_info["user"]["id"]),
+            "organizations": user_info.get("organizations", []),
+        }
+
+        access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
+
+        return User(
+            principal=principal,
+            attributes=access_attributes,
+        )
+
+    async def close(self):
+        """Clean up any resources."""
+        pass
+
+    def get_auth_error_message(self, scope: dict | None = None) -> str:
+        """Return GitHub-specific authentication error message."""
+        return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
+
+
+async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
+    """Fetch user info and organizations from GitHub API."""
+    headers = {
+        "Authorization": f"Bearer {access_token}",
+        "Accept": "application/vnd.github.v3+json",
+        "User-Agent": "llama-stack",
+    }
+
+    async with httpx.AsyncClient() as client:
+        user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
+        user_response.raise_for_status()
+        user_data = user_response.json()
+
+        return {
+            "user": user_data,
+        }
+

 def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
     """Factory function to create the appropriate auth provider."""
-    provider_type = config.provider_type.lower()
-
-    if provider_type == "custom":
-        return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
-    elif provider_type == "oauth2_token":
-        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
+    provider_config = config.provider_config
+
+    if isinstance(provider_config, CustomAuthConfig):
+        return CustomAuthProvider(provider_config)
+    elif isinstance(provider_config, OAuth2TokenAuthConfig):
+        return OAuth2TokenAuthProvider(provider_config)
+    elif isinstance(provider_config, GitHubTokenAuthConfig):
+        return GitHubTokenAuthProvider(provider_config)
     else:
-        supported_providers = ", ".join([t.value for t in AuthProviderType])
-        raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
+        raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
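A sketch of the new factory dispatch (names from this module and the datatypes above; `GitHubTokenAuthConfig()` is valid with its defaults):

```python
# A sketch: the factory now dispatches on the config's concrete class
# instead of a provider_type string.
from llama_stack.distribution.datatypes import AuthenticationConfig, GitHubTokenAuthConfig

auth_config = AuthenticationConfig(provider_config=GitHubTokenAuthConfig())
provider = create_auth_provider(auth_config)
print(type(provider).__name__)        # GitHubTokenAuthProvider
print(provider.get_auth_error_message())
```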

View file

@@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError

 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.distribution.access_control.access_control import AccessDeniedError
-from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import (
+    AuthenticationRequiredError,
+    LoggingConfig,
+    StackRunConfig,
+)
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -217,7 +221,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
             # Get auth attributes from the request scope
             user_attributes = request.scope.get("user_attributes", {})
             principal = request.scope.get("principal", "")
-            user = User(principal, user_attributes)
+            user = User(principal=principal, attributes=user_attributes)

             await log_request_pre_validation(request)
@@ -405,13 +409,13 @@ def main(args: argparse.Namespace | None = None):
         args = parser.parse_args()

     log_line = ""
-    if args.config:
+    if hasattr(args, "config") and args.config:
         # if the user provided a config file, use it, even if template was specified
         config_file = Path(args.config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
         log_line = f"Using config file: {config_file}"
-    elif args.template:
+    elif hasattr(args, "template") and args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
@@ -455,7 +459,7 @@ def main(args: argparse.Namespace | None = None):
     # Add authentication middleware if configured
     if config.server.auth:
-        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
+        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
     else:
         if config.server.quota:

View file

@@ -6,7 +6,7 @@

 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
@@ -19,6 +19,7 @@ from llama_stack.schema_utils import json_schema_type
 class MilvusVectorIOConfig(BaseModel):
     db_path: str
     kvstore: KVStoreConfig
+    consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")

     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
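A sketch of instantiating the extended config (the kvstore class and paths are illustrative):

```python
# A sketch: consistency_level defaults to "Strong" and can be overridden.
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = MilvusVectorIOConfig(
    db_path="/tmp/milvus.db",                                  # illustrative path
    kvstore=SqliteKVStoreConfig(db_path="/tmp/milvus_kv.db"),  # illustrative path
    consistency_level="Bounded",                               # overrides "Strong"
)
```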

View file

@@ -154,10 +154,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)

         for vector_db_data in stored_vector_dbs:
-            vector_db = VectorDB.mdel_validate_json(vector_db_data)
+            vector_db = VectorDB.model_validate_json(vector_db_data)
             index = VectorDBWithIndex(
                 vector_db,
-                index=await MilvusIndex(
+                index=MilvusIndex(
                     client=self.client,
                     collection_name=vector_db.identifier,
                     consistency_level=self.config.consistency_level,
@@ -259,6 +259,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         assert self.kvstore is not None
         key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
         await self.kvstore.delete(key)
+        if store_id in self.openai_vector_stores:
+            del self.openai_vector_stores[store_id]

     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
@@ -377,6 +379,29 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
             return {}

+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        """Update vector store file metadata in Milvus database."""
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+                return
+
+            file_data = [
+                {
+                    "store_file_id": f"{store_id}_{file_id}",
+                    "store_id": store_id,
+                    "file_id": file_id,
+                    "file_info": json.dumps(file_info),
+                }
+            ]
+            await asyncio.to_thread(
+                self.client.upsert,
+                collection_name="openai_vector_store_files",
+                data=file_data,
+            )
+        except Exception as e:
+            logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
+            raise
+
     async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
         """Load vector store file contents from Milvus database."""
         try:
@@ -405,29 +430,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
             return []

-    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
-        """Update vector store file metadata in Milvus database."""
-        try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
-                return
-
-            file_data = [
-                {
-                    "store_file_id": f"{store_id}_{file_id}",
-                    "store_id": store_id,
-                    "file_id": file_id,
-                    "file_info": json.dumps(file_info),
-                }
-            ]
-            await asyncio.to_thread(
-                self.client.upsert,
-                collection_name="openai_vector_store_files",
-                data=file_data,
-            )
-        except Exception as e:
-            logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
-            raise
-
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         """Delete vector store file metadata from Milvus database."""
         try:

View file

@@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.log import get_logger

 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+from .sqlstore import SqlStoreType

 logger = get_logger(name=__name__, category="authorized_sqlstore")
@@ -71,9 +72,18 @@ class AuthorizedSqlStore:
         :param sql_store: Base SqlStore implementation to wrap
         """
         self.sql_store = sql_store
+        self._detect_database_type()
         self._validate_sql_optimized_policy()

+    def _detect_database_type(self) -> None:
+        """Detect the database type from the underlying SQL store."""
+        if not hasattr(self.sql_store, "config"):
+            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
+
+        self.database_type = self.sql_store.config.type
+        if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _validate_sql_optimized_policy(self) -> None:
         """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@@ -181,6 +191,50 @@ class AuthorizedSqlStore:
         else:
             return self._build_conservative_where_clause()

+    def _json_extract(self, column: str, path: str) -> str:
+        """Extract JSON value (keeping JSON type).
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _json_extract_text(self, column: str, path: str) -> str:
+        """Extract JSON value as text.
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value as text
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->>'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _get_public_access_conditions(self) -> list[str]:
+        """Get the SQL conditions for public access."""
+        if self.database_type == SqlStoreType.postgres:
+            # Postgres stores JSON null as 'null'
+            return ["access_attributes::text = 'null'"]
+        elif self.database_type == SqlStoreType.sqlite:
+            return ["access_attributes = 'null'"]
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _build_default_policy_where_clause(self) -> str:
         """Build SQL WHERE clause for the default policy.
@@ -189,29 +243,32 @@ class AuthorizedSqlStore:
         """
         current_user = get_authenticated_user()

+        base_conditions = self._get_public_access_conditions()
         if not current_user or not current_user.attributes:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            return f"({' OR '.join(base_conditions)})"
         else:
-            base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
-
             user_attr_conditions = []
             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []
                     for value in user_values:
-                        value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
+                        # Check if JSON array contains the value
+                        escaped_value = value.replace("'", "''")
+                        json_text = self._json_extract_text("access_attributes", attr_key)
+                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                     if value_conditions:
-                        category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
+                        # Check if the category is missing (NULL)
+                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                         user_matches_category = f"({' OR '.join(value_conditions)})"
                         user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

             if user_attr_conditions:
                 all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
                 base_conditions.append(all_requirements_met)
-                return f"({' OR '.join(base_conditions)})"
-            else:
-                return f"({' OR '.join(base_conditions)})"
+
+            return f"({' OR '.join(base_conditions)})"

     def _build_conservative_where_clause(self) -> str:
@@ -222,5 +279,8 @@ class AuthorizedSqlStore:
         current_user = get_authenticated_user()

         if not current_user:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            base_conditions = self._get_public_access_conditions()
+            return f"({' OR '.join(base_conditions)})"
+
         return "1=1"

View file

@@ -4,9 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from abc import abstractmethod
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Annotated, Literal
@@ -19,7 +18,7 @@ from .api import SqlStore

 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]


-class SqlStoreType(Enum):
+class SqlStoreType(StrEnum):
     sqlite = "sqlite"
     postgres = "postgres"
@@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):

 class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["sqlite"] = SqlStoreType.sqlite.value
+    type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
         description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):

 class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["postgres"] = SqlStoreType.postgres.value
+    type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
     host: str = "localhost"
     port: int = 5432
     db: str = "llamastack"
@@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:

 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+    if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
         from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

         impl = SqlAlchemySqlStoreImpl(config)
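The switch from `Enum` to `StrEnum` is what lets the comparison above drop the `.value` round-trips; a quick sketch (Python 3.11+):

```python
# A sketch: StrEnum members are strings, so they compare equal to raw values.
from enum import StrEnum

class SqlStoreType(StrEnum):  # mirror of the class above
    sqlite = "sqlite"
    postgres = "postgres"

assert SqlStoreType.sqlite == "sqlite"      # would be False with a plain Enum
assert "postgres" in [SqlStoreType.sqlite, SqlStoreType.postgres]
```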

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .nvidia import get_distribution_template # noqa: F401

View file

@@ -0,0 +1,29 @@
version: 2
distribution_spec:
  description: Use NVIDIA NIM for running LLM inference, evaluation and safety
  providers:
    inference:
    - remote::nvidia
    vector_io:
    - inline::faiss
    safety:
    - remote::nvidia
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - remote::nvidia
    post_training:
    - remote::nvidia
    datasetio:
    - inline::localfs
    - remote::nvidia
    scoring:
    - inline::basic
    tool_runtime:
    - inline::rag-runtime
image_type: conda
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@@ -0,0 +1,149 @@
# NVIDIA Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### NVIDIA API Keys
Make sure you have access to an NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
### Deploy NeMo Microservices Platform
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
## Supported Services
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
### Inference: NVIDIA NIM
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
Note: For improved inference speeds, use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct",
    "namespace": "meta",
    "config": {
      "model": "meta/llama-3.2-1b-instruct",
      "nim_deployment": {
        "image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
        "image_tag": "1.8.3",
        "pvc_size": "25Gi",
        "gpu": 1,
        "additional_envs": {
          "NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
        }
      }
    }
  }'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
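To check on the deployment, a GET against the same path used for deletion (shown in the next example) should return its status; the exact response shape depends on your NeMo microservices version:
```sh
export NEMO_URL="http://nemo.test"

# Poll the deployment status (assumes the status endpoint mirrors the
# deletion path shown below)
curl "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.2-1b-instruct" \
  -H 'accept: application/json'
```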
You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```
## Running Llama Stack with NVIDIA
You can do this via Conda or venv (which build the distribution code), or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
### Via Conda
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
## Example Notebooks
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.

View file

@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"],
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["remote::nvidia"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs", "remote::nvidia"],
"scoring": ["inline::basic"],
"tool_runtime": ["inline::rag-runtime"],
}
inference_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(),
)
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
datasetio_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NvidiaDatasetIOConfig.sample_run_config(),
)
eval_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = {
"nvidia": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"datasetio": [datasetio_provider],
"eval": [eval_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
],
"eval": [eval_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"NVIDIA_API_KEY": (
"",
"NVIDIA API Key",
),
"NVIDIA_APPEND_API_VERSION": (
"True",
"Whether to append the API version to the base_url",
),
## Nemo Customizer related variables
"NVIDIA_DATASET_NAMESPACE": (
"default",
"NVIDIA Dataset Namespace",
),
"NVIDIA_PROJECT_ID": (
"test-project",
"NVIDIA Project ID",
),
"NVIDIA_CUSTOMIZER_URL": (
"https://customizer.api.nvidia.com",
"NVIDIA Customizer URL",
),
"NVIDIA_OUTPUT_MODEL_DIR": (
"test-example-model@v1",
"NVIDIA Output Model Directory",
),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"NVIDIA_GUARDRAILS_CONFIG_ID": (
"self-check",
"NVIDIA Guardrail Configuration ID",
),
"NVIDIA_EVALUATOR_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Evaluator Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
},
)

View file

@ -0,0 +1,119 @@
version: 2
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: nvidia
provider_type: remote::nvidia
config:
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/localfs_datasetio.db
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,226 @@
version: 2
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: nvidia
provider_type: remote::nvidia
config:
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
datasetio:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: meta/llama3-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-405b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-1b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-3b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-11b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-90b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: nvidia
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: nvidia/nv-embedqa-e5-v5
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: nvidia/nv-embedqa-mistral-7b-v2
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: snowflake/arctic-embed-l
provider_id: nvidia
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,6 @@
import NextAuth from "next-auth";
import { authOptions } from "@/lib/auth";
const handler = NextAuth(authOptions);
export { handler as GET, handler as POST };

View file

@ -0,0 +1,118 @@
"use client";
import { signIn, signOut, useSession } from "next-auth/react";
import { Button } from "@/components/ui/button";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Copy, Check, Home, Github } from "lucide-react";
import { useState } from "react";
import { useRouter } from "next/navigation";
export default function SignInPage() {
const { data: session, status } = useSession();
const [copied, setCopied] = useState(false);
const router = useRouter();
const handleCopyToken = async () => {
if (session?.accessToken) {
await navigator.clipboard.writeText(session.accessToken);
setCopied(true);
setTimeout(() => setCopied(false), 2000);
}
};
if (status === "loading") {
return (
<div className="flex items-center justify-center min-h-screen">
<div className="text-muted-foreground">Loading...</div>
</div>
);
}
return (
<div className="flex items-center justify-center min-h-screen">
<Card className="w-[400px]">
<CardHeader>
<CardTitle>Authentication</CardTitle>
<CardDescription>
{session
? "You are successfully authenticated!"
: "Sign in with GitHub to use your access token as an API key"}
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{!session ? (
<Button
onClick={() => {
console.log("Signing in with GitHub...");
signIn("github", { callbackUrl: "/auth/signin" }).catch(
(error) => {
console.error("Sign in error:", error);
},
);
}}
className="w-full"
variant="default"
>
<Github className="mr-2 h-4 w-4" />
Sign in with GitHub
</Button>
) : (
<div className="space-y-4">
<div className="text-sm text-muted-foreground">
Signed in as {session.user?.email}
</div>
{session.accessToken && (
<div className="space-y-2">
<div className="text-sm font-medium">
GitHub Access Token:
</div>
<div className="flex gap-2">
<code className="flex-1 p-2 bg-muted rounded text-xs break-all">
{session.accessToken}
</code>
<Button
size="sm"
variant="outline"
onClick={handleCopyToken}
>
{copied ? (
<Check className="h-4 w-4" />
) : (
<Copy className="h-4 w-4" />
)}
</Button>
</div>
<div className="text-xs text-muted-foreground">
This GitHub token will be used as your API key for
authenticated Llama Stack requests.
</div>
</div>
)}
<div className="flex gap-2">
<Button onClick={() => router.push("/")} className="flex-1">
<Home className="mr-2 h-4 w-4" />
Go to Dashboard
</Button>
<Button
onClick={() => signOut()}
variant="outline"
className="flex-1"
>
Sign out
</Button>
</div>
</div>
)}
</CardContent>
</Card>
</div>
);
}

View file

@ -1,5 +1,6 @@
import type { Metadata } from "next";
import { ThemeProvider } from "@/components/ui/theme-provider";
import { SessionProvider } from "@/components/providers/session-provider";
import { Geist, Geist_Mono } from "next/font/google";
import { ModeToggle } from "@/components/ui/mode-toggle";
import "./globals.css";
@ -21,11 +22,13 @@ export const metadata: Metadata = {
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "@/components/layout/app-sidebar";
import { SignInButton } from "@/components/ui/sign-in-button";
export default function Layout({ children }: { children: React.ReactNode }) {
return (
<html lang="en" suppressHydrationWarning>
<body className={`${geistSans.variable} ${geistMono.variable} font-sans`}>
<SessionProvider>
<ThemeProvider
attribute="class"
defaultTheme="system"
@ -41,7 +44,8 @@ export default function Layout({ children }: { children: React.ReactNode }) {
<SidebarTrigger />
</div>
<div className="flex-1 text-center"></div>
<div className="flex-none flex items-center gap-2">
<SignInButton />
<ModeToggle />
</div>
</div>
@ -49,6 +53,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
</main>
</SidebarProvider>
</ThemeProvider>
</SessionProvider>
</body>
</html>
);

View file

@ -4,11 +4,12 @@ import { useEffect, useState } from "react";
import { useParams } from "next/navigation";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
import { useAuthClient } from "@/hooks/use-auth-client";
export default function ChatCompletionDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const [completionDetail, setCompletionDetail] =
useState<ChatCompletion | null>(null);
@ -45,7 +46,7 @@ export default function ChatCompletionDetailPage() {
};
fetchCompletionDetail();
}, [id, client]);
return (
<ChatCompletionDetailView

View file

@ -5,11 +5,12 @@ import { useParams } from "next/navigation";
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { ResponseDetailView } from "@/components/responses/responses-detail";
import { useAuthClient } from "@/hooks/use-auth-client";
export default function ResponseDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null,
@ -109,7 +110,7 @@ export default function ResponseDetailPage() {
};
fetchResponseDetail();
}, [id, client]);
return (
<ResponseDetailView

View file

@ -12,24 +12,34 @@ jest.mock("next/navigation", () => ({
}),
}));
// Mock next-auth
jest.mock("next-auth/react", () => ({
useSession: () => ({
status: "authenticated",
data: { accessToken: "mock-token" },
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
jest.mock("@/lib/format-message-content");
// Mock the auth client hook
const mockClient = {
chat: {
completions: {
list: jest.fn(),
},
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
// Mock the usePagination hook
const mockLoadMore = jest.fn();
jest.mock("@/hooks/use-pagination", () => ({
usePagination: jest.fn(() => ({
data: [],
status: "idle",
@ -47,7 +57,7 @@ import {
} from "@/lib/format-message-content";
// Import the mocked hook
import { usePagination } from "@/hooks/use-pagination";
const mockedUsePagination = usePagination as jest.MockedFunction<
typeof usePagination
>;

View file

@ -10,8 +10,7 @@ import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
import { usePagination } from "@/hooks/use-pagination";
interface ChatCompletionsTableProps {
/** Optional pagination configuration */
@ -32,12 +31,15 @@ function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
export function ChatCompletionsTable({
paginationOptions,
}: ChatCompletionsTableProps) {
const fetchFunction = async (
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
params: {
after?: string;
limit: number;
model?: string;
order?: string;
},
) => {
const response = await client.chat.completions.list({
after: params.after,
limit: params.limit,

View file

@ -12,7 +12,7 @@ jest.mock("next/navigation", () => ({
}));
// Mock the useInfiniteScroll hook
jest.mock("@/hooks/use-infinite-scroll", () => ({
useInfiniteScroll: jest.fn((onLoadMore, options) => {
const ref = React.useRef(null);

View file

@ -4,7 +4,7 @@ import { useRouter } from "next/navigation";
import { useRef } from "react";
import { truncateText } from "@/lib/truncate-text";
import { PaginationStatus } from "@/lib/types";
import { useInfiniteScroll } from "@/hooks/use-infinite-scroll";
import {
Table,
TableBody,

View file

@ -0,0 +1,7 @@
"use client";
import { SessionProvider as NextAuthSessionProvider } from "next-auth/react";
export function SessionProvider({ children }: { children: React.ReactNode }) {
return <NextAuthSessionProvider>{children}</NextAuthSessionProvider>;
}

View file

@ -12,21 +12,31 @@ jest.mock("next/navigation", () => ({
}),
}));
// Mock next-auth
jest.mock("next-auth/react", () => ({
useSession: () => ({
status: "authenticated",
data: { accessToken: "mock-token" },
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Mock the auth client hook
const mockClient = {
responses: {
list: jest.fn(),
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
// Mock the usePagination hook
const mockLoadMore = jest.fn();
jest.mock("@/hooks/use-pagination", () => ({
usePagination: jest.fn(() => ({
data: [],
status: "idle",
@ -40,7 +50,7 @@ jest.mock("@/hooks/usePagination", () => ({
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Import the mocked hook
import { usePagination } from "@/hooks/use-pagination";
const mockedUsePagination = usePagination as jest.MockedFunction<
typeof usePagination
>;

View file

@ -6,8 +6,7 @@ import {
UsePaginationOptions,
} from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import { usePagination } from "@/hooks/use-pagination";
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
import {
isMessageInput,
@ -125,12 +124,15 @@ function formatResponseToRow(response: OpenAIResponse): LogTableRow {
}
export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
const fetchFunction = async (
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
params: {
after?: string;
limit: number;
model?: string;
order?: string;
},
) => {
const response = await client.responses.list({
after: params.after,
limit: params.limit,

View file

@ -0,0 +1,25 @@
"use client";
import { User } from "lucide-react";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { Button } from "./button";
export function SignInButton() {
const { data: session, status } = useSession();
return (
<Button variant="ghost" size="sm" asChild>
<Link href="/auth/signin" className="flex items-center">
<User className="mr-2 h-4 w-4" />
<span>
{status === "loading"
? "Loading..."
: session
? session.user?.email || "Signed In"
: "Sign In"}
</span>
</Link>
</Button>
);
}

View file

@ -0,0 +1,24 @@
import { useSession } from "next-auth/react";
import { useMemo } from "react";
import LlamaStackClient from "llama-stack-client";
export function useAuthClient() {
const { data: session } = useSession();
const client = useMemo(() => {
const clientHostname =
typeof window !== "undefined" ? window.location.origin : "";
const options: any = {
baseURL: `${clientHostname}/api`,
};
if (session?.accessToken) {
options.apiKey = session.accessToken;
}
return new LlamaStackClient(options);
}, [session?.accessToken]);
return client;
}

View file

@ -2,6 +2,9 @@
import { useState, useCallback, useEffect, useRef } from "react";
import { PaginationStatus, UsePaginationOptions } from "@/lib/types";
import { useSession } from "next-auth/react";
import { useAuthClient } from "@/hooks/use-auth-client";
import { useRouter } from "next/navigation";
interface PaginationState<T> {
data: T[];
@ -28,13 +31,18 @@ export interface PaginationReturn<T> {
}
interface UsePaginationParams<T> extends UsePaginationOptions {
fetchFunction: (
client: ReturnType<typeof useAuthClient>,
params: {
after?: string;
limit: number;
model?: string;
order?: string;
},
) => Promise<PaginationResponse<T>>;
errorMessagePrefix: string;
enabled?: boolean;
useAuth?: boolean;
}
export function usePagination<T>({
@ -43,7 +51,12 @@ export function usePagination<T>({
order = "desc",
fetchFunction,
errorMessagePrefix,
enabled = true,
useAuth = true,
}: UsePaginationParams<T>): PaginationReturn<T> {
const { status: sessionStatus } = useSession();
const client = useAuthClient();
const router = useRouter();
const [state, setState] = useState<PaginationState<T>>({
data: [],
status: "loading",
@ -74,7 +87,7 @@ export function usePagination<T>({
error: null,
}));
const response = await fetchFunction(client, {
after: after || undefined,
limit: fetchLimit,
...(model && { model }),
@ -91,6 +104,17 @@
status: "idle",
}));
} catch (err) {
// Check if it's a 401 unauthorized error
if (
err &&
typeof err === "object" &&
"status" in err &&
err.status === 401
) {
router.push("/auth/signin");
return;
}
const errorMessage = isInitialLoad
? `Failed to load ${errorMessagePrefix}. Please try refreshing the page.`
: `Failed to load more ${errorMessagePrefix}. Please try again.`;
@ -107,7 +131,7 @@ export function usePagination<T>({
}));
}
},
[limit, model, order, fetchFunction, errorMessagePrefix, client, router],
);
/**
@ -120,17 +144,28 @@ export function usePagination<T>({
}
}, [fetchData]);
// Auto-load initial data on mount when enabled
useEffect(() => {
// If using auth, wait for session to load
const isAuthReady = !useAuth || sessionStatus !== "loading";
const shouldFetch = enabled && isAuthReady;
if (shouldFetch && !hasFetchedInitialData.current) {
hasFetchedInitialData.current = true;
fetchData();
} else if (!shouldFetch) {
// Reset the flag when disabled so it can fetch when re-enabled
hasFetchedInitialData.current = false;
}
}, [fetchData, enabled, useAuth, sessionStatus]);
// Override status if we're waiting for auth
const effectiveStatus =
useAuth && sessionStatus === "loading" ? "loading" : state.status;
return {
data: state.data,
status: effectiveStatus,
hasMore: state.hasMore,
error: state.error,
loadMore,

View file

@ -0,0 +1,11 @@
/**
* Next.js Instrumentation
* This file is used for initializing monitoring, tracing, or other observability tools.
* It runs once when the server starts, before any application code.
*
* Learn more: https://nextjs.org/docs/app/building-your-application/optimizing/instrumentation
*/
export async function register() {
await import("./lib/config-validator");
}

View file

@ -0,0 +1,38 @@
import { NextAuthOptions } from "next-auth";
import GithubProvider from "next-auth/providers/github";
export const authOptions: NextAuthOptions = {
providers: [
GithubProvider({
clientId: process.env.GITHUB_CLIENT_ID!,
clientSecret: process.env.GITHUB_CLIENT_SECRET!,
authorization: {
params: {
scope: "read:user user:email",
},
},
}),
],
debug: process.env.NODE_ENV === "development",
callbacks: {
async jwt({ token, account }) {
// Persist the OAuth access_token to the token right after signin
if (account) {
token.accessToken = account.access_token;
}
return token;
},
async session({ session, token }) {
// Send properties to the client, like an access_token from a provider.
session.accessToken = token.accessToken as string;
return session;
},
},
pages: {
signIn: "/auth/signin",
error: "/auth/signin", // Redirect errors to our custom page
},
session: {
strategy: "jwt",
},
};

View file

@ -1,6 +0,0 @@
import LlamaStackClient from "llama-stack-client";
export const client = new LlamaStackClient({
baseURL:
typeof window !== "undefined" ? `${window.location.origin}/api` : "/api",
});

View file

@ -0,0 +1,56 @@
/**
* Validates environment configuration for the application
* This is called during server initialization
*/
export function validateServerConfig() {
if (process.env.NODE_ENV === "development") {
console.log("🚀 Starting Llama Stack UI Server...");
// Check optional configurations
const optionalConfigs = {
NEXTAUTH_URL: process.env.NEXTAUTH_URL || "http://localhost:8322",
LLAMA_STACK_BACKEND_URL:
process.env.LLAMA_STACK_BACKEND_URL || "http://localhost:8321",
LLAMA_STACK_UI_PORT: process.env.LLAMA_STACK_UI_PORT || "8322",
GITHUB_CLIENT_ID: process.env.GITHUB_CLIENT_ID,
GITHUB_CLIENT_SECRET: process.env.GITHUB_CLIENT_SECRET,
};
console.log("\n📋 Configuration:");
console.log(` - NextAuth URL: ${optionalConfigs.NEXTAUTH_URL}`);
console.log(` - Backend URL: ${optionalConfigs.LLAMA_STACK_BACKEND_URL}`);
console.log(` - UI Port: ${optionalConfigs.LLAMA_STACK_UI_PORT}`);
// Check GitHub OAuth configuration
if (
!optionalConfigs.GITHUB_CLIENT_ID ||
!optionalConfigs.GITHUB_CLIENT_SECRET
) {
console.log(
"\n📝 GitHub OAuth not configured (authentication features disabled)",
);
console.log(" To enable GitHub OAuth:");
console.log(" 1. Go to https://github.com/settings/applications/new");
console.log(
" 2. Set Application name: Llama Stack UI (or your preferred name)",
);
console.log(" 3. Set Homepage URL: http://localhost:8322");
console.log(
" 4. Set Authorization callback URL: http://localhost:8322/api/auth/callback/github",
);
console.log(
" 5. Create the app and copy the Client ID and Client Secret",
);
console.log(" 6. Add them to your .env.local file:");
console.log(" GITHUB_CLIENT_ID=your_client_id");
console.log(" GITHUB_CLIENT_SECRET=your_client_secret");
} else {
console.log(" - GitHub OAuth: ✅ Configured");
}
console.log("");
}
}
// Call this function when the module is imported
validateServerConfig();

View file

@ -18,6 +18,7 @@
"llama-stack-client": "0.2.13", "llama-stack-client": "0.2.13",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.3", "next": "15.3.3",
"next-auth": "^4.24.11",
"next-themes": "^0.4.6", "next-themes": "^0.4.6",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",
@ -548,7 +549,6 @@
"version": "7.27.1", "version": "7.27.1",
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz", "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz",
"integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==", "integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==",
"dev": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {
"node": ">=6.9.0" "node": ">=6.9.0"
@ -2423,6 +2423,15 @@
"node": ">=12.4.0" "node": ">=12.4.0"
} }
}, },
"node_modules/@panva/hkdf": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@panva/hkdf/-/hkdf-1.2.1.tgz",
"integrity": "sha512-6oclG6Y3PiDFcoyk8srjLfVKyMfVCKJ27JwNPViuXziFpmdz+MZnZN/aKY0JGXgYuO/VghU0jcOAZgWXZ1Dmrw==",
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/panva"
}
},
"node_modules/@pkgr/core": { "node_modules/@pkgr/core": {
"version": "0.2.4", "version": "0.2.4",
"resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.4.tgz",
@ -5279,7 +5288,6 @@
"version": "0.7.2", "version": "0.7.2",
"resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz",
"integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==",
"dev": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {
"node": ">= 0.6" "node": ">= 0.6"
@ -9036,6 +9044,15 @@
"jiti": "lib/jiti-cli.mjs" "jiti": "lib/jiti-cli.mjs"
} }
}, },
"node_modules/jose": {
"version": "4.15.9",
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz",
"integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==",
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/panva"
}
},
"node_modules/js-tokens": { "node_modules/js-tokens": {
"version": "4.0.0", "version": "4.0.0",
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
@ -9949,6 +9966,38 @@
} }
} }
}, },
"node_modules/next-auth": {
"version": "4.24.11",
"resolved": "https://registry.npmjs.org/next-auth/-/next-auth-4.24.11.tgz",
"integrity": "sha512-pCFXzIDQX7xmHFs4KVH4luCjaCbuPRtZ9oBUjUhOk84mZ9WVPf94n87TxYI4rSRf9HmfHEF8Yep3JrYDVOo3Cw==",
"license": "ISC",
"dependencies": {
"@babel/runtime": "^7.20.13",
"@panva/hkdf": "^1.0.2",
"cookie": "^0.7.0",
"jose": "^4.15.5",
"oauth": "^0.9.15",
"openid-client": "^5.4.0",
"preact": "^10.6.3",
"preact-render-to-string": "^5.1.19",
"uuid": "^8.3.2"
},
"peerDependencies": {
"@auth/core": "0.34.2",
"next": "^12.2.5 || ^13 || ^14 || ^15",
"nodemailer": "^6.6.5",
"react": "^17.0.2 || ^18 || ^19",
"react-dom": "^17.0.2 || ^18 || ^19"
},
"peerDependenciesMeta": {
"@auth/core": {
"optional": true
},
"nodemailer": {
"optional": true
}
}
},
"node_modules/next-themes": { "node_modules/next-themes": {
"version": "0.4.6", "version": "0.4.6",
"resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.4.6.tgz", "resolved": "https://registry.npmjs.org/next-themes/-/next-themes-0.4.6.tgz",
@ -10071,6 +10120,12 @@
"dev": true, "dev": true,
"license": "MIT" "license": "MIT"
}, },
"node_modules/oauth": {
"version": "0.9.15",
"resolved": "https://registry.npmjs.org/oauth/-/oauth-0.9.15.tgz",
"integrity": "sha512-a5ERWK1kh38ExDEfoO6qUHJb32rd7aYmPHuyCu3Fta/cnICvYmgd2uhuKXvPD+PXB+gCEYYEaQdIRAjCOwAKNA==",
"license": "MIT"
},
"node_modules/object-assign": { "node_modules/object-assign": {
"version": "4.1.1", "version": "4.1.1",
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
@ -10081,6 +10136,15 @@
"node": ">=0.10.0" "node": ">=0.10.0"
} }
}, },
"node_modules/object-hash": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz",
"integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==",
"license": "MIT",
"engines": {
"node": ">= 6"
}
},
"node_modules/object-inspect": { "node_modules/object-inspect": {
"version": "1.13.4", "version": "1.13.4",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
@ -10194,6 +10258,15 @@
"url": "https://github.com/sponsors/ljharb" "url": "https://github.com/sponsors/ljharb"
} }
}, },
"node_modules/oidc-token-hash": {
"version": "5.1.0",
"resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.1.0.tgz",
"integrity": "sha512-y0W+X7Ppo7oZX6eovsRkuzcSM40Bicg2JEJkDJ4irIt1wsYAP5MLSNv+QAogO8xivMffw/9OvV3um1pxXgt1uA==",
"license": "MIT",
"engines": {
"node": "^10.13.0 || >=12.0.0"
}
},
"node_modules/on-finished": { "node_modules/on-finished": {
"version": "2.4.1", "version": "2.4.1",
"resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz",
@ -10233,6 +10306,39 @@
"url": "https://github.com/sponsors/sindresorhus" "url": "https://github.com/sponsors/sindresorhus"
} }
}, },
"node_modules/openid-client": {
"version": "5.7.1",
"resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.7.1.tgz",
"integrity": "sha512-jDBPgSVfTnkIh71Hg9pRvtJc6wTwqjRkN88+gCFtYWrlP4Yx2Dsrow8uPi3qLr/aeymPF3o2+dS+wOpglK04ew==",
"license": "MIT",
"dependencies": {
"jose": "^4.15.9",
"lru-cache": "^6.0.0",
"object-hash": "^2.2.0",
"oidc-token-hash": "^5.0.3"
},
"funding": {
"url": "https://github.com/sponsors/panva"
}
},
"node_modules/openid-client/node_modules/lru-cache": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz",
"integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==",
"license": "ISC",
"dependencies": {
"yallist": "^4.0.0"
},
"engines": {
"node": ">=10"
}
},
"node_modules/openid-client/node_modules/yallist": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
"integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==",
"license": "ISC"
},
"node_modules/optionator": { "node_modules/optionator": {
"version": "0.9.4", "version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
@ -10560,6 +10666,34 @@
"node": "^10 || ^12 || >=14" "node": "^10 || ^12 || >=14"
} }
}, },
"node_modules/preact": {
"version": "10.26.9",
"resolved": "https://registry.npmjs.org/preact/-/preact-10.26.9.tgz",
"integrity": "sha512-SSjF9vcnF27mJK1XyFMNJzFd5u3pQiATFqoaDy03XuN00u4ziveVVEGt5RKJrDR8MHE/wJo9Nnad56RLzS2RMA==",
"license": "MIT",
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/preact"
}
},
"node_modules/preact-render-to-string": {
"version": "5.2.6",
"resolved": "https://registry.npmjs.org/preact-render-to-string/-/preact-render-to-string-5.2.6.tgz",
"integrity": "sha512-JyhErpYOvBV1hEPwIxc/fHWXPfnEGdRKxc8gFdAZ7XV4tlzyzG847XAyEZqoDnynP88akM4eaHcSOzNcLWFguw==",
"license": "MIT",
"dependencies": {
"pretty-format": "^3.8.0"
},
"peerDependencies": {
"preact": ">=10"
}
},
"node_modules/preact-render-to-string/node_modules/pretty-format": {
"version": "3.8.0",
"resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-3.8.0.tgz",
"integrity": "sha512-WuxUnVtlWL1OfZFQFuqvnvs6MiAGk9UNsBostyBOB0Is9wb5uRESevA6rnl/rkksXaGX3GzZhPup5d6Vp1nFew==",
"license": "MIT"
},
"node_modules/prelude-ls": { "node_modules/prelude-ls": {
"version": "1.2.1", "version": "1.2.1",
"resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz",
@ -12409,6 +12543,15 @@
} }
} }
}, },
"node_modules/uuid": {
"version": "8.3.2",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz",
"integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==",
"license": "MIT",
"bin": {
"uuid": "dist/bin/uuid"
}
},
"node_modules/v8-compile-cache-lib": { "node_modules/v8-compile-cache-lib": {
"version": "3.0.1", "version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",

View file

@ -23,6 +23,7 @@
"llama-stack-client": "0.2.13", "llama-stack-client": "0.2.13",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.3", "next": "15.3.3",
"next-auth": "^4.24.11",
"next-themes": "^0.4.6", "next-themes": "^0.4.6",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",

llama_stack/ui/types/next-auth.d.ts vendored Normal file
View file

@ -0,0 +1,7 @@
import NextAuth from "next-auth";
declare module "next-auth" {
interface Session {
accessToken?: string;
}
}

View file

@ -86,6 +86,7 @@ unit = [
"sqlalchemy[asyncio]>=2.0.41", "sqlalchemy[asyncio]>=2.0.41",
"blobfile", "blobfile",
"faiss-cpu", "faiss-cpu",
"pymilvus>=2.5.12",
] ]
# These are the core dependencies required for running integration tests. They are shared across all # These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment # providers. If a provider requires additional dependencies, please add them to your environment
@ -106,6 +107,7 @@ test = [
"sqlalchemy", "sqlalchemy",
"sqlalchemy[asyncio]>=2.0.41", "sqlalchemy[asyncio]>=2.0.41",
"requests", "requests",
"pymilvus>=2.5.12",
] ]
docs = [ docs = [
"setuptools", "setuptools",

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
from unittest.mock import patch
import pytest
from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
def get_postgres_config():
"""Get PostgreSQL configuration if tests are enabled."""
return PostgresSqlStoreConfig(
host=os.environ.get("POSTGRES_HOST", "localhost"),
port=int(os.environ.get("POSTGRES_PORT", "5432")),
db=os.environ.get("POSTGRES_DB", "llamastack"),
user=os.environ.get("POSTGRES_USER", "llamastack"),
password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
)
def get_sqlite_config():
"""Get SQLite configuration with temporary database."""
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
tmp_file.close()
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
@pytest.mark.asyncio
@pytest.mark.parametrize(
"backend_config",
[
pytest.param(
("postgres", get_postgres_config),
marks=pytest.mark.skipif(
not os.environ.get("ENABLE_POSTGRES_TESTS"),
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
),
id="postgres",
),
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
],
)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_json_comparison(mock_get_authenticated_user, backend_config):
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
backend_name, config_func = backend_config
# Handle different config types
if backend_name == "postgres":
config = config_func()
cleanup_path = None
else: # sqlite
config, cleanup_path = config_func()
try:
base_sqlstore = SqlAlchemySqlStoreImpl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
# Create test table
table_name = f"test_json_comparison_{backend_name}"
await authorized_store.create_table(
table=table_name,
schema={
"id": ColumnType.STRING,
"data": ColumnType.STRING,
},
)
try:
# Test with no authenticated user (should handle JSON null comparison)
mock_get_authenticated_user.return_value = None
# Insert some test data
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
# Test with authenticated user
test_user = User("test-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = test_user
# Insert data with user attributes
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 2
# Test with non-admin user
regular_user = User("regular-user", {"roles": ["user"]})
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
# Test the category missing branch: user with multiple attributes
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
mock_get_authenticated_user.return_value = multi_user
# Insert record with multi-user (has both roles and teams)
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
# Test different user types to create records with different attribute patterns
# Record with only roles (teams category will be missing)
roles_only_user = User("roles-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = roles_only_user
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
# Record with only teams (roles category will be missing)
teams_only_user = User("teams-user", {"teams": ["dev"]})
mock_get_authenticated_user.return_value = teams_only_user
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
# Record with different roles/teams (shouldn't match our test user)
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
mock_get_authenticated_user.return_value = different_user
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
# Should see:
# - public record (1) - no access_attributes
# - admin record (2) - user matches roles=admin, teams missing (allowed)
# - multi_user record (3) - user matches both roles=admin and teams=dev
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
# Should NOT see:
# - different_user record (6) - user doesn't match roles=user or teams=qa
expected_ids = {"1", "2", "3", "4", "5"}
actual_ids = {record["id"] for record in result.data}
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
# Verify the category missing logic specifically
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
assert category_test_ids == {"4", "5"}, (
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
)
finally:
# Clean up records
for record_id in ["1", "2", "3", "4", "5", "6"]:
try:
await base_sqlstore.delete(table_name, {"id": record_id})
except Exception:
pass
finally:
# Clean up temporary SQLite database file if needed
if cleanup_path:
try:
os.unlink(cleanup_path)
except OSError:
pass

View file

@ -0,0 +1,420 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import time
from unittest.mock import AsyncMock

import numpy as np
import pytest
import pytest_asyncio
from pymilvus import Collection, MilvusClient, connections

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.utils.kvstore import kvstore_impl

# TODO: Refactor these to be for inline vector-io providers
MILVUS_ALIAS = "test_milvus"
COLLECTION_PREFIX = "test_collection"


@pytest.fixture(scope="session")
def loop():
    return asyncio.new_event_loop()


@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
    class MockInferenceAPI:
        async def embed_batch(self, texts: list[str]) -> list[list[float]]:
            return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]

    return MockInferenceAPI()


@pytest_asyncio.fixture
async def unique_kvstore_config(tmp_path_factory):
    # Generate a unique filename for this test
    unique_id = f"test_kv_{np.random.randint(1e6)}"
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / f"{unique_id}.db")
    return SqliteKVStoreConfig(db_path=db_path)


@pytest_asyncio.fixture(scope="session", autouse=True)
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / "test_milvus.db")
    client = MilvusClient(db_path)
    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
    connections.connect(alias=MILVUS_ALIAS, uri=db_path)
    index = MilvusIndex(client, name, consistency_level="Strong")
    index.db_path = db_path
    yield index


@pytest_asyncio.fixture(scope="session")
async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
    config = MilvusVectorIOConfig(
        db_path=milvus_vec_index.db_path,
        kvstore=SqliteKVStoreConfig(),
    )
    adapter = MilvusVectorIOAdapter(
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(
        VectorDB(
            identifier=adapter.metadata_collection_name,
            provider_id="test_provider",
            embedding_model="test_model",
            embedding_dimension=128,
        )
    )
    yield adapter
    await adapter.shutdown()


@pytest.mark.asyncio
async def test_cache_contains_initial_collection(milvus_vec_adapter):
    coll_name = milvus_vec_adapter.metadata_collection_name
    assert coll_name in milvus_vec_adapter.cache


@pytest.mark.asyncio
async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
    resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
    assert resp.chunks[0].content == sample_chunks[0].content


@pytest.mark.asyncio
async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
    query_emb = np.random.rand(embedding_dimension).astype(np.float32)
    resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
    assert isinstance(resp, QueryChunksResponse)
    assert len(resp.chunks) == 2


@pytest.mark.asyncio
async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
    embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
    await milvus_vec_index.add_chunks(sample_chunks, embeddings)
    coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
    ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
    flat_ids = [i["id"] for i in ids]
    assert len(flat_ids) == len(set(flat_ids))


@pytest.mark.asyncio
async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
    kvstore = await kvstore_impl(unique_kvstore_config)
    vector_db = VectorDB(
        identifier="test_db",
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
        metadata={"test_key": "test_value"},
    )
    test_vector_db_data = vector_db.model_dump_json()
    await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)

    tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
        config=MilvusVectorIOConfig(
            db_path=milvus_vec_index.db_path,
            kvstore=unique_kvstore_config,
        ),
        inference_api=None,
        files_api=None,
    )
    await tmp_milvus_vec_adapter.initialize()

    vector_db = VectorDB(
        identifier="test_db",
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
    )
    test_vector_db_data = vector_db.model_dump_json()
    await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)

    assert milvus_vec_index.client is not None
    assert isinstance(milvus_vec_index.client, MilvusClient)
    assert tmp_milvus_vec_adapter.cache is not None
    # registering a vector won't update the cache or openai_vector_store collection name
    assert (
        tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
        or tmp_milvus_vec_adapter.openai_vector_stores
    )


@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts(
    tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
):
    adapter1 = MilvusVectorIOAdapter(
        config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter1.initialize()
    dummy = VectorDB(
        identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
    )
    await adapter1.register_vector_db(dummy)
    await adapter1.shutdown()

    await adapter1.initialize()
    assert "foo_db" in adapter1.cache
    await adapter1.shutdown()


@pytest.mark.asyncio
async def test_register_and_unregister_vector_db(milvus_vec_adapter):
    try:
        connections.disconnect(MILVUS_ALIAS)
    except Exception as _:
        pass

    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
    unique_id = f"foo_db_{np.random.randint(1e6)}"
    dummy = VectorDB(
        identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
    )
    await milvus_vec_adapter.register_vector_db(dummy)
    assert dummy.identifier in milvus_vec_adapter.cache

    if dummy.identifier in milvus_vec_adapter.cache:
        index = milvus_vec_adapter.cache[dummy.identifier].index
        if hasattr(index, "client") and hasattr(index.client, "_using"):
            index.client._using = MILVUS_ALIAS

    await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
    assert dummy.identifier not in milvus_vec_adapter.cache


@pytest.mark.asyncio
async def test_query_unregistered_raises(milvus_vec_adapter):
    fake_emb = np.zeros(8, dtype=np.float32)
    with pytest.raises(AttributeError):
        await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)


@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
    fake_index = AsyncMock()
    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

    chunks = ["chunk1", "chunk2"]
    await milvus_vec_adapter.insert_chunks("db1", chunks)

    fake_index.insert_chunks.assert_awaited_once_with(chunks)


@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await milvus_vec_adapter.insert_chunks("db_not_exist", [])


@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
    expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
    fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

    response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})

    fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
    assert response is expected


@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
    milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await milvus_vec_adapter.query_chunks("db_missing", "q", None)


@pytest.mark.asyncio
async def test_save_openai_vector_store(milvus_vec_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
        "name": "Test Store",
        "description": "A test OpenAI vector store",
        "vector_db_id": "test_db",
        "embedding_model": "test_model",
    }

    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)

    assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store


@pytest.mark.asyncio
async def test_update_openai_vector_store(milvus_vec_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
        "name": "Test Store",
        "description": "A test OpenAI vector store",
        "vector_db_id": "test_db",
        "embedding_model": "test_model",
    }

    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
    openai_vector_store["description"] = "Updated description"
    await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
    assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store


@pytest.mark.asyncio
async def test_delete_openai_vector_store(milvus_vec_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
        "name": "Test Store",
        "description": "A test OpenAI vector store",
        "vector_db_id": "test_db",
        "embedding_model": "test_model",
    }

    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
    await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
    assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores


@pytest.mark.asyncio
async def test_load_openai_vector_stores(milvus_vec_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
        "name": "Test Store",
        "description": "A test OpenAI vector store",
        "vector_db_id": "test_db",
        "embedding_model": "test_model",
    }

    await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
    loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
    assert loaded_stores[store_id] == openai_vector_store


@pytest.mark.asyncio
async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"

    file_info = {
        "id": file_id,
        "status": "completed",
        "vector_store_id": store_id,
        "attributes": {},
        "filename": "test_file.txt",
        "created_at": int(time.time()),
    }

    file_contents = [
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    # validating we don't raise an exception
    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)


@pytest.mark.asyncio
async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"

    file_info = {
        "id": file_id,
        "status": "completed",
        "vector_store_id": store_id,
        "attributes": {},
        "filename": "test_file.txt",
        "created_at": int(time.time()),
    }

    file_contents = [
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

    updated_file_info = file_info.copy()
    updated_file_info["filename"] = "updated_test_file.txt"

    await milvus_vec_adapter._update_openai_vector_store_file(
        store_id,
        file_id,
        updated_file_info,
    )

    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
    assert loaded_contents == updated_file_info
    assert loaded_contents != file_info


@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"

    file_info = {
        "id": file_id,
        "status": "completed",
        "vector_store_id": store_id,
        "attributes": {},
        "filename": "test_file.txt",
        "created_at": int(time.time()),
    }

    file_contents = [
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
    assert loaded_contents == file_contents


@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"

    file_info = {
        "id": file_id,
        "status": "completed",
        "vector_store_id": store_id,
        "attributes": {},
        "filename": "test_file.txt",
        "created_at": int(time.time()),
    }

    file_contents = [
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
    await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)

    loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
    assert loaded_contents == []
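Since these fixtures drive Milvus Lite through a local .db file rather than a running server, a standalone sketch of the underlying client calls may help when reproducing failures; the collection name and field layout below are illustrative, not MilvusIndex internals:

# Minimal Milvus Lite round-trip with pymilvus, independent of the adapter above.
import numpy as np
from pymilvus import MilvusClient

client = MilvusClient("./sketch_milvus.db")  # local file, no server required
client.create_collection(collection_name="sketch", dimension=128)
client.insert(
    collection_name="sketch",
    data=[{"id": i, "vector": np.random.rand(128).tolist(), "text": f"chunk-{i}"} for i in range(4)],
)
hits = client.search(collection_name="sketch", data=[np.random.rand(128).tolist()], limit=2)
print(hits[0])  # top-2 nearest neighbors for the query vector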


@@ -11,10 +11,16 @@ import pytest
 from fastapi import FastAPI
 from fastapi.testclient import TestClient

-from llama_stack.distribution.datatypes import AuthenticationConfig
+from llama_stack.distribution.datatypes import (
+    AuthenticationConfig,
+    AuthProviderType,
+    CustomAuthConfig,
+    OAuth2IntrospectionConfig,
+    OAuth2JWKSConfig,
+    OAuth2TokenAuthConfig,
+)
 from llama_stack.distribution.server.auth import AuthenticationMiddleware
 from llama_stack.distribution.server.auth_providers import (
-    AuthProviderType,
     get_attributes_from_claims,
 )
@@ -61,24 +67,11 @@ def invalid_token():
 def http_app(mock_auth_endpoint):
     app = FastAPI()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.CUSTOM,
-        config={"endpoint": mock_auth_endpoint},
-    )
-    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
-
-    @app.get("/test")
-    def test_endpoint():
-        return {"message": "Authentication successful"}
-
-    return app
-
-
-@pytest.fixture
-def k8s_app():
-    app = FastAPI()
-    auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.KUBERNETES,
-        config={"api_server_url": "https://kubernetes.default.svc"},
+        provider_config=CustomAuthConfig(
+            type=AuthProviderType.CUSTOM,
+            endpoint=mock_auth_endpoint,
+        ),
+        access_policy=[],
     )
     app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@@ -94,11 +87,6 @@ def http_client(http_app):
     return TestClient(http_app)


-@pytest.fixture
-def k8s_client(k8s_app):
-    return TestClient(k8s_app)
-
-
 @pytest.fixture
 def mock_scope():
     return {
@@ -117,18 +105,11 @@ def mock_scope():
 def mock_http_middleware(mock_auth_endpoint):
     mock_app = AsyncMock()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.CUSTOM,
-        config={"endpoint": mock_auth_endpoint},
-    )
-    return AuthenticationMiddleware(mock_app, auth_config), mock_app
-
-
-@pytest.fixture
-def mock_k8s_middleware():
-    mock_app = AsyncMock()
-    auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.KUBERNETES,
-        config={"api_server_url": "https://kubernetes.default.svc"},
+        provider_config=CustomAuthConfig(
+            type=AuthProviderType.CUSTOM,
+            endpoint=mock_auth_endpoint,
+        ),
+        access_policy=[],
     )
     return AuthenticationMiddleware(mock_app, auth_config), mock_app
@@ -161,13 +142,14 @@ async def mock_post_exception(*args, **kwargs):
 def test_missing_auth_header(http_client):
     response = http_client.get("/test")
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Authentication required" in response.json()["error"]["message"]
+    assert "validated by mock-auth-service" in response.json()["error"]["message"]


 def test_invalid_auth_header_format(http_client):
     response = http_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Invalid Authorization header format" in response.json()["error"]["message"]


 @patch("httpx.AsyncClient.post", new=mock_post_success)
@@ -262,14 +244,14 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
 def oauth2_app():
     app = FastAPI()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.OAUTH2_TOKEN,
-        config={
-            "jwks": {
-                "uri": "http://mock-authz-service/token/introspect",
-                "key_recheck_period": "3600",
-            },
-            "audience": "llama-stack",
-        },
+        provider_config=OAuth2TokenAuthConfig(
+            type=AuthProviderType.OAUTH2_TOKEN,
+            jwks=OAuth2JWKSConfig(
+                uri="http://mock-authz-service/token/introspect",
+            ),
+            audience="llama-stack",
+        ),
+        access_policy=[],
     )
     app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@@ -288,13 +270,14 @@ def oauth2_client(oauth2_app):
 def test_missing_auth_header_oauth2(oauth2_client):
     response = oauth2_client.get("/test")
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Authentication required" in response.json()["error"]["message"]
+    assert "OAuth2 Bearer token" in response.json()["error"]["message"]


 def test_invalid_auth_header_format_oauth2(oauth2_client):
     response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Invalid Authorization header format" in response.json()["error"]["message"]


 async def mock_jwks_response(*args, **kwargs):
@@ -358,15 +341,16 @@ async def mock_auth_jwks_response(*args, **kwargs):
 def oauth2_app_with_jwks_token():
     app = FastAPI()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.OAUTH2_TOKEN,
-        config={
-            "jwks": {
-                "uri": "http://mock-authz-service/token/introspect",
-                "key_recheck_period": "3600",
-                "token": "my-jwks-token",
-            },
-            "audience": "llama-stack",
-        },
+        provider_config=OAuth2TokenAuthConfig(
+            type=AuthProviderType.OAUTH2_TOKEN,
+            jwks=OAuth2JWKSConfig(
+                uri="http://mock-authz-service/token/introspect",
+                key_recheck_period=3600,
+                token="my-jwks-token",
+            ),
+            audience="llama-stack",
+        ),
+        access_policy=[],
     )
     app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@@ -449,11 +433,15 @@ def mock_introspection_endpoint():
 def introspection_app(mock_introspection_endpoint):
     app = FastAPI()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.OAUTH2_TOKEN,
-        config={
-            "jwks": None,
-            "introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
-        },
+        provider_config=OAuth2TokenAuthConfig(
+            type=AuthProviderType.OAUTH2_TOKEN,
+            introspection=OAuth2IntrospectionConfig(
+                url=mock_introspection_endpoint,
+                client_id="myclient",
+                client_secret="abcdefg",
+            ),
+        ),
+        access_policy=[],
     )
     app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@@ -468,22 +456,22 @@ def introspection_app(mock_introspection_endpoint):
 def introspection_app_with_custom_mapping(mock_introspection_endpoint):
     app = FastAPI()
     auth_config = AuthenticationConfig(
-        provider_type=AuthProviderType.OAUTH2_TOKEN,
-        config={
-            "jwks": None,
-            "introspection": {
-                "url": mock_introspection_endpoint,
-                "client_id": "myclient",
-                "client_secret": "abcdefg",
-                "send_secret_in_body": "true",
-            },
-            "claims_mapping": {
+        provider_config=OAuth2TokenAuthConfig(
+            type=AuthProviderType.OAUTH2_TOKEN,
+            introspection=OAuth2IntrospectionConfig(
+                url=mock_introspection_endpoint,
+                client_id="myclient",
+                client_secret="abcdefg",
+                send_secret_in_body=True,
+            ),
+            claims_mapping={
                 "sub": "roles",
                 "scope": "roles",
                 "groups": "teams",
                 "aud": "namespaces",
             },
-        },
+        ),
+        access_policy=[],
     )
     app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@@ -507,13 +495,14 @@ def introspection_client_with_custom_mapping(introspection_app_with_custom_mappi
 def test_missing_auth_header_introspection(introspection_client):
     response = introspection_client.get("/test")
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Authentication required" in response.json()["error"]["message"]
+    assert "OAuth2 Bearer token" in response.json()["error"]["message"]


 def test_invalid_auth_header_format_introspection(introspection_client):
     response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
     assert response.status_code == 401
-    assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
+    assert "Invalid Authorization header format" in response.json()["error"]["message"]


 async def mock_introspection_active(*args, **kwargs):
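Read together, these hunks show the shape of the migrated config model. As a rough composite (import path and field names as they appear in this file; the endpoint URL is a placeholder), an introspection-based setup now looks like:

# Hedged composite of the new-style config, assembled from the hunks above.
from llama_stack.distribution.datatypes import (
    AuthenticationConfig,
    AuthProviderType,
    OAuth2IntrospectionConfig,
    OAuth2TokenAuthConfig,
)

auth_config = AuthenticationConfig(
    provider_config=OAuth2TokenAuthConfig(
        type=AuthProviderType.OAUTH2_TOKEN,
        introspection=OAuth2IntrospectionConfig(
            url="https://authz.example.com/token/introspect",  # placeholder endpoint
            client_id="myclient",
            client_secret="abcdefg",
        ),
    ),
    access_policy=[],  # empty policy, as in the fixtures above
)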


@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, patch

import httpx
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware


class MockResponse:
    def __init__(self, status_code, json_data):
        self.status_code = status_code
        self._json_data = json_data

    def json(self):
        return self._json_data

    def raise_for_status(self):
        if self.status_code != 200:
            # Create a mock request for the HTTPStatusError
            mock_request = httpx.Request("GET", "https://api.github.com/user")
            raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)


@pytest.fixture
def github_token_app():
    app = FastAPI()

    # Configure GitHub token auth
    auth_config = AuthenticationConfig(
        provider_config=GitHubTokenAuthConfig(
            type=AuthProviderType.GITHUB_TOKEN,
            github_api_base_url="https://api.github.com",
            claims_mapping={
                "login": "username",
                "id": "user_id",
                "organizations": "teams",
            },
        ),
        access_policy=[],
    )

    # Add auth middleware
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return app


@pytest.fixture
def github_token_client(github_token_app):
    return TestClient(github_token_app)


def test_authenticated_endpoint_without_token(github_token_client):
    """Test accessing protected endpoint without token"""
    response = github_token_client.get("/test")
    assert response.status_code == 401
    assert "Authentication required" in response.json()["error"]["message"]
    assert "GitHub access token" in response.json()["error"]["message"]


def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
    """Test accessing protected endpoint with invalid bearer format"""
    response = github_token_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
    assert response.status_code == 401
    assert "Invalid Authorization header format" in response.json()["error"]["message"]


@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
    """Test accessing protected endpoint with valid GitHub token"""
    # Mock the GitHub API responses
    mock_client = AsyncMock()
    mock_client_class.return_value.__aenter__.return_value = mock_client

    # Mock successful user API response
    mock_client.get.side_effect = [
        MockResponse(
            200,
            {
                "login": "testuser",
                "id": 12345,
                "email": "test@example.com",
                "name": "Test User",
            },
        ),
        MockResponse(
            200,
            [
                {"login": "test-org-1"},
                {"login": "test-org-2"},
            ],
        ),
    ]

    response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
    assert response.status_code == 200
    assert response.json()["message"] == "Authentication successful"

    # Verify the GitHub API was called correctly
    assert mock_client.get.call_count == 1
    calls = mock_client.get.call_args_list
    assert calls[0][0][0] == "https://api.github.com/user"

    # Check authorization header was passed
    assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"


@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
    """Test accessing protected endpoint with invalid GitHub token"""
    # Mock the GitHub API to return 401 Unauthorized
    mock_client = AsyncMock()
    mock_client_class.return_value.__aenter__.return_value = mock_client

    # Mock failed user API response
    mock_client.get.return_value = MockResponse(401, {"message": "Bad credentials"})

    response = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})
    assert response.status_code == 401
    assert (
        "GitHub token validation failed. Please check your token and try again." in response.json()["error"]["message"]
    )


@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_github_enterprise_support(mock_client_class):
    """Test GitHub Enterprise support with custom API base URL"""
    app = FastAPI()

    # Configure GitHub token auth with enterprise URL
    auth_config = AuthenticationConfig(
        provider_config=GitHubTokenAuthConfig(
            type=AuthProviderType.GITHUB_TOKEN,
            github_api_base_url="https://github.enterprise.com/api/v3",
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    client = TestClient(app)

    # Mock the GitHub Enterprise API responses
    mock_client = AsyncMock()
    mock_client_class.return_value.__aenter__.return_value = mock_client

    # Mock successful user API response
    mock_client.get.side_effect = [
        MockResponse(
            200,
            {
                "login": "enterprise_user",
                "id": 99999,
                "email": "user@enterprise.com",
            },
        ),
        MockResponse(
            200,
            [
                {"login": "enterprise-org"},
            ],
        ),
    ]

    response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
    assert response.status_code == 200

    # Verify the correct GitHub Enterprise URLs were called
    assert mock_client.get.call_count == 1
    calls = mock_client.get.call_args_list
    assert calls[0][0][0] == "https://github.enterprise.com/api/v3/user"


def test_github_token_auth_error_message_format(github_token_client):
    """Test that the error message for missing auth is properly formatted"""
    response = github_token_client.get("/test")
    assert response.status_code == 401
    error_data = response.json()
    assert "error" in error_data
    assert "message" in error_data["error"]
    assert "Authentication required" in error_data["error"]["message"]
    assert "https://docs.github.com" in error_data["error"]["message"]  # Contains link to GitHub docs

@@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
     # Test scenarios with different access control patterns
     test_scenarios = [
-        # Scenario 1: Public record (no access control)
+        # Scenario 1: Public record (no access control - represents None user insert)
         {"id": "1", "name": "public", "access_attributes": None},
-        # Scenario 2: Empty access control (should be treated as public)
-        {"id": "2", "name": "empty", "access_attributes": {}},
-        # Scenario 3: Record with roles requirement
-        {"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
-        # Scenario 4: Record with multiple attribute categories
-        {"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
-        # Scenario 5: Record with teams only (missing roles category)
-        {"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
-        # Scenario 6: Record with roles and projects
+        # Scenario 2: Record with roles requirement
+        {"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
+        # Scenario 3: Record with multiple attribute categories
+        {"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
+        # Scenario 4: Record with teams only (missing roles category)
+        {"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
+        # Scenario 5: Record with roles and projects
         {
-            "id": "6",
+            "id": "5",
             "name": "admin-project-x",
             "access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
         },
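As a quick worked check of what the renumbered scenarios imply, assuming the same all-present-categories-must-match rule sketched earlier and a hypothetical user with roles=['admin'] and teams=['ml-team']:

# Hedged worked example; `allowed` mirrors the earlier sketch, not stack code.
def allowed(attrs, user):
    return not attrs or all(
        any(v in user.get(cat, []) for v in vals) for cat, vals in attrs.items()
    )

user = {"roles": ["admin"], "teams": ["ml-team"]}
scenarios = {
    "1": None,
    "2": {"roles": ["admin"]},
    "3": {"roles": ["admin"], "teams": ["ml-team"]},
    "4": {"teams": ["ml-team"]},
    "5": {"roles": ["admin"], "projects": ["project-x"]},
}
visible = {rid for rid, attrs in scenarios.items() if allowed(attrs, user)}
assert visible == {"1", "2", "3", "4"}  # record 5 also requires project-x membership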

uv.lock (generated)

File diff suppressed because it is too large.