diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index af2058b9a..263828e1c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,10 +1,8 @@ # What does this PR do? -[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] + -[//]: # (If resolving an issue, uncomment and update the line below) -[//]: # (Closes #[issue-number]) + + ## Test Plan -[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] - -[//]: # (## Documentation) + diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml new file mode 100644 index 000000000..6cba4fdc3 --- /dev/null +++ b/.github/actions/setup-runner/action.yml @@ -0,0 +1,22 @@ +name: Setup runner +description: Prepare a runner for the tests (install uv, python, project dependencies, etc.) +runs: + using: "composite" + steps: + - name: Install uv + uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 + with: + python-version: "3.10" + activate-environment: true + version: 0.7.6 + + - name: Install dependencies + shell: bash + run: | + uv sync --all-groups + uv pip install ollama faiss-cpu + # always test against the latest version of the client + # TODO: this is not necessarily a good idea. we need to test against both published and latest + # to find out backwards compatibility issues. + uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install -e . diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 82a76ad32..a3a746246 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -23,23 +23,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - auth-provider: [kubernetes] + auth-provider: [oauth2_token] fail-fast: false # we want to run all tests regardless of failure steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - activate-environment: true + - name: Install dependencies + uses: ./.github/actions/setup-runner - - name: Set Up Environment and Install Dependencies + - name: Build Llama Stack run: | - uv sync --extra dev --extra test - uv pip install -e . 
llama stack build --template ollama --image-type venv - name: Install minikube @@ -47,29 +42,53 @@ jobs: uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19 - name: Start minikube - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | minikube start kubectl get pods -A - name: Configure Kube Auth - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token + cat <> $GITHUB_ENV + echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV + echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV + echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" - if: ${{ matrix.auth-provider == 'kubernetes' }} + if: ${{ matrix.auth-provider == 'oauth2_token' }} run: | run_dir=$(mktemp -d) cat <<'EOF' > $run_dir/run.yaml @@ -81,10 +100,10 @@ jobs: port: 8321 EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml - source .venv/bin/activate nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & - name: Wait for Llama Stack server to be ready diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index da41e2185..d78e82c9d 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime] client-type: [library, http] fail-fast: false # we want to run all tests regardless of failure @@ -32,24 +32,14 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - activate-environment: true + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Setup ollama uses: ./.github/actions/setup-ollama - - name: Set Up Environment and Install Dependencies + - name: Build 
Llama Stack run: | - uv sync --extra dev --extra test - uv pip install ollama faiss-cpu - # always test against the latest version of the client - # TODO: this is not necessarily a good idea. we need to test against both published and latest - # to find out backwards compatibility issues. - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main - uv pip install -e . llama stack build --template ollama --image-type venv - name: Start Llama Stack server in background @@ -57,7 +47,6 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - source .venv/bin/activate LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv & - name: Wait for Llama Stack server to be ready @@ -85,6 +74,7 @@ jobs: echo "Ollama health check failed" exit 1 fi + - name: Check Storage and Memory Available Before Tests if: ${{ always() }} run: | @@ -100,7 +90,7 @@ jobs: else stack_config="http://localhost:8321" fi - uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ + uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 4df04fbb0..2bbd52c53 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -29,6 +29,7 @@ jobs: - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 env: SKIP: no-commit-to-branch + RUFF_OUTPUT_FORMAT: github - name: Verify if there are any diff files after pre-commit run: | diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 3c1682833..cf53459b9 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -50,21 +50,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Print build dependencies run: | @@ -79,7 +66,6 @@ jobs: - name: Print dependencies in the image if: matrix.image-type == 'venv' run: | - source test/bin/activate uv pip list build-single-provider: @@ -88,21 +74,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . 
+ - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build a single provider run: | @@ -114,21 +87,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build a single provider run: | @@ -152,21 +112,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.10' - - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Install LlamaStack - run: | - uv venv - source .venv/bin/activate - uv pip install -e . + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Pin template to UBI9 base run: | diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index 2e18fc5eb..06ab7cf3c 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -25,15 +25,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: "3.10" - - - name: Set Up Environment and Install Dependencies - run: | - uv sync --extra dev --extra test - uv pip install -e . 
+ - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Apply image type to config file run: | @@ -59,7 +52,6 @@ jobs: env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - source ci-test/bin/activate uv run pip list nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d2dd34e05..fc0459f0f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -30,17 +30,11 @@ jobs: - "3.12" - "3.13" steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: ${{ matrix.python }} - - - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - with: - python-version: ${{ matrix.python }} - enable-cache: false + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Run unit tests run: | diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 04e23bca9..981332a77 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -37,16 +37,8 @@ jobs: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.11' - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 - - - name: Sync with uv - run: uv sync --extra docs + - name: Install dependencies + uses: ./.github/actions/setup-runner - name: Build HTML run: | diff --git a/.gitignore b/.gitignore index 6884e37e1..2ac7e1e3a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dev_requirements.txt build .DS_Store llama_stack/configs/* +.cursor/ xcuserdata/ *.hmap .DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e78fcd158..aaec469e4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,7 @@ repos: - black==24.3.0 - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.6.3 + rev: 0.7.8 hooks: - id: uv-lock - id: uv-export @@ -61,6 +61,7 @@ repos: "--frozen", "--no-hashes", "--no-emit-project", + "--no-default-groups", "--output-file=requirements.txt" ] @@ -88,20 +89,17 @@ repos: - id: distro-codegen name: Distribution Template Codegen additional_dependencies: - - uv==0.6.0 - entry: uv run --extra codegen ./scripts/distro_codegen.py + - uv==0.7.8 + entry: uv run --group codegen ./scripts/distro_codegen.py language: python pass_filenames: false require_serial: true files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ - -- repo: local - hooks: - id: openapi-codegen name: API Spec Codegen additional_dependencies: - - uv==0.6.2 - entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' + - uv==0.7.8 + entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' language: python pass_filenames: false require_serial: true diff --git a/CHANGELOG.md b/CHANGELOG.md index a0b008982..f7644a5af 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +# v0.2.7 +Published on: 2025-05-16T20:38:10Z + +## Highlights + +This is a small update. But a couple highlights: + +* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece. +* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134 +* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases. + + +--- + +# v0.2.6 +Published on: 2025-05-12T18:06:52Z + + + +--- + # v0.2.5 Published on: 2025-05-04T20:16:49Z diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7c3e3e2f..8f71a6ba1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -167,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme. ```bash -cd docs -uv sync --extra docs - # This rebuilds the documentation pages. -uv run make html +uv run --with ".[docs]" make -C docs/ html # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation. -uv run sphinx-autobuild source build/html --write-all +uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all ``` ### Update API Documentation diff --git a/MANIFEST.in b/MANIFEST.in index 879a9cbd4..88bd11767 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include pyproject.toml -include llama_stack/templates/dependencies.json include llama_stack/models/llama/llama3/tokenizer.model include llama_stack/models/llama/llama4/tokenizer.model include llama_stack/distribution/*.sh diff --git a/README.md b/README.md index 5dfe3577a..e54b505cf 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Here is a list of the various API providers and available distributions that can | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:| | Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | -| SambaNova | Hosted | | ✅ | | | | +| SambaNova | Hosted | | ✅ | | ✅ | | | Cerebras | Hosted | | ✅ | | | | | Fireworks | Hosted | ✅ | ✅ | ✅ | | | | AWS Bedrock | Hosted | | ✅ | | ✅ | | diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 6378a5ced..9c1c3170f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -518,6 +518,74 @@ } }, "/v1/openai/v1/responses": { + "get": { + "responses": { + "200": { + "description": "A ListOpenAIResponseObject.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListOpenAIResponseObject" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + 
} + }, + "tags": [ + "Agents" + ], + "description": "List all OpenAI responses.", + "parameters": [ + { + "name": "after", + "in": "query", + "description": "The ID of the last response to return.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "limit", + "in": "query", + "description": "The number of responses to return.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "model", + "in": "query", + "description": "The model to filter responses by.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "order", + "in": "query", + "description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').", + "required": false, + "schema": { + "$ref": "#/components/schemas/Order" + } + } + ] + }, "post": { "responses": { "200": { @@ -1395,7 +1463,7 @@ ] } }, - "/v1/openai/v1/responses/{id}": { + "/v1/openai/v1/responses/{response_id}": { "get": { "responses": { "200": { @@ -1427,7 +1495,7 @@ "description": "Retrieve an OpenAI response by its ID.", "parameters": [ { - "name": "id", + "name": "response_id", "in": "path", "description": "The ID of the OpenAI response to retrieve.", "required": true, @@ -2926,6 +2994,97 @@ } } }, + "/v1/openai/v1/responses/{response_id}/input_items": { + "get": { + "responses": { + "200": { + "description": "An ListOpenAIResponseInputItem.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListOpenAIResponseInputItem" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Agents" + ], + "description": "List input items for a given OpenAI response.", + "parameters": [ + { + "name": "response_id", + "in": "path", + "description": "The ID of the response to retrieve input items for.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "after", + "in": "query", + "description": "An item ID to list items after, used for pagination.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "before", + "in": "query", + "description": "An item ID to list items before, used for pagination.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "include", + "in": "query", + "description": "Additional fields to include in the response.", + "required": false, + "schema": { + "type": "array", + "items": { + "type": "string" + } + } + }, + { + "name": "limit", + "in": "query", + "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "order", + "in": "query", + "description": "The order to return the input items in. 
Default is desc.", + "required": false, + "schema": { + "$ref": "#/components/schemas/Order" + } + } + ] + } + }, "/v1/providers": { "get": { "responses": { @@ -6742,6 +6901,9 @@ }, { "$ref": "#/components/schemas/OpenAIResponseInputToolFunction" + }, + { + "$ref": "#/components/schemas/OpenAIResponseInputToolMCP" } ], "discriminator": { @@ -6749,7 +6911,8 @@ "mapping": { "web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch", "file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch", - "function": "#/components/schemas/OpenAIResponseInputToolFunction" + "function": "#/components/schemas/OpenAIResponseInputToolFunction", + "mcp": "#/components/schemas/OpenAIResponseInputToolMCP" } } }, @@ -6839,6 +7002,110 @@ ], "title": "OpenAIResponseInputToolFunction" }, + "OpenAIResponseInputToolMCP": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "mcp", + "default": "mcp" + }, + "server_label": { + "type": "string" + }, + "server_url": { + "type": "string" + }, + "headers": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "require_approval": { + "oneOf": [ + { + "type": "string", + "const": "always" + }, + { + "type": "string", + "const": "never" + }, + { + "type": "object", + "properties": { + "always": { + "type": "array", + "items": { + "type": "string" + } + }, + "never": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "title": "ApprovalFilter" + } + ], + "default": "never" + }, + "allowed_tools": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "object", + "properties": { + "tool_names": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "title": "AllowedToolsFilter" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "type", + "server_label", + "server_url", + "require_approval" + ], + "title": "OpenAIResponseInputToolMCP" + }, "OpenAIResponseInputToolWebSearch": { "type": "object", "properties": { @@ -6951,15 +7218,15 @@ "OpenAIResponseOutputMessageFunctionToolCall": { "type": "object", "properties": { - "arguments": { - "type": "string" - }, "call_id": { "type": "string" }, "name": { "type": "string" }, + "arguments": { + "type": "string" + }, "type": { "type": "string", "const": "function_call", @@ -6974,12 +7241,10 @@ }, "additionalProperties": false, "required": [ - "arguments", "call_id", "name", - "type", - "id", - "status" + "arguments", + "type" ], "title": "OpenAIResponseOutputMessageFunctionToolCall" }, @@ -7027,6 +7292,9 @@ "type": "string", "description": "The underlying LLM used for completions." }, + "instructions": { + "type": "string" + }, "previous_response_id": { "type": "string", "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses." 
@@ -7142,6 +7410,12 @@ }, { "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } ], "discriminator": { @@ -7149,15 +7423,126 @@ "mapping": { "message": "#/components/schemas/OpenAIResponseMessage", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", - "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall", + "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall", + "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } } }, + "OpenAIResponseOutputMessageMCPCall": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "mcp_call", + "default": "mcp_call" + }, + "arguments": { + "type": "string" + }, + "name": { + "type": "string" + }, + "server_label": { + "type": "string" + }, + "error": { + "type": "string" + }, + "output": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "arguments", + "name", + "server_label" + ], + "title": "OpenAIResponseOutputMessageMCPCall" + }, + "OpenAIResponseOutputMessageMCPListTools": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "mcp_list_tools", + "default": "mcp_list_tools" + }, + "server_label": { + "type": "string" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "input_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "input_schema", + "name" + ], + "title": "MCPListToolsTool" + } + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "server_label", + "tools" + ], + "title": "OpenAIResponseOutputMessageMCPListTools" + }, "OpenAIResponseObjectStream": { "oneOf": [ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated" }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta" + }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -7166,6 +7551,7 @@ "propertyName": "type", "mapping": { "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated", + "response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -7208,6 +7594,41 @@ ], "title": "OpenAIResponseObjectStreamResponseCreated" }, + "OpenAIResponseObjectStreamResponseOutputTextDelta": { + "type": "object", + "properties": { + "content_index": { + "type": "integer" + }, + "delta": { + "type": "string" + }, + "item_id": { + "type": "string" + }, + "output_index": { + "type": "integer" + }, + "sequence_number": { + "type": "integer" + }, + "type": { + "type": "string", + "const": "response.output_text.delta", + "default": "response.output_text.delta" + } + }, + "additionalProperties": false, + "required": [ + 
"content_index", + "delta", + "item_id", + "output_index", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseOutputTextDelta" + }, "CreateUploadSessionRequest": { "type": "object", "properties": { @@ -9173,9 +9594,6 @@ "toolgroup_id": { "type": "string" }, - "tool_host": { - "$ref": "#/components/schemas/ToolHost" - }, "description": { "type": "string" }, @@ -9217,21 +9635,11 @@ "provider_id", "type", "toolgroup_id", - "tool_host", "description", "parameters" ], "title": "Tool" }, - "ToolHost": { - "type": "string", - "enum": [ - "distribution", - "client", - "model_context_protocol" - ], - "title": "ToolHost" - }, "ToolGroup": { "type": "object", "properties": { @@ -10068,6 +10476,130 @@ ], "title": "ListModelsResponse" }, + "ListOpenAIResponseInputItem": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseInput" + } + }, + "object": { + "type": "string", + "const": "list", + "default": "list" + } + }, + "additionalProperties": false, + "required": [ + "data", + "object" + ], + "title": "ListOpenAIResponseInputItem" + }, + "ListOpenAIResponseObject": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" + } + }, + "has_more": { + "type": "boolean" + }, + "first_id": { + "type": "string" + }, + "last_id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "list", + "default": "list" + } + }, + "additionalProperties": false, + "required": [ + "data", + "has_more", + "first_id", + "last_id", + "object" + ], + "title": "ListOpenAIResponseObject" + }, + "OpenAIResponseObjectWithInput": { + "type": "object", + "properties": { + "created_at": { + "type": "integer" + }, + "error": { + "$ref": "#/components/schemas/OpenAIResponseError" + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "object": { + "type": "string", + "const": "response", + "default": "response" + }, + "output": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseOutput" + } + }, + "parallel_tool_calls": { + "type": "boolean", + "default": false + }, + "previous_response_id": { + "type": "string" + }, + "status": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "truncation": { + "type": "string" + }, + "user": { + "type": "string" + }, + "input": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseInput" + } + } + }, + "additionalProperties": false, + "required": [ + "created_at", + "id", + "model", + "object", + "output", + "parallel_tool_calls", + "status", + "input" + ], + "title": "OpenAIResponseObjectWithInput" + }, "ListProvidersResponse": { "type": "object", "properties": { @@ -11605,6 +12137,10 @@ "type": "string", "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" + }, + "mode": { + "type": "string", + "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"." 
} }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 012610d02..1afe870cf 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -349,6 +349,53 @@ paths: $ref: '#/components/schemas/CreateAgentTurnRequest' required: true /v1/openai/v1/responses: + get: + responses: + '200': + description: A ListOpenAIResponseObject. + content: + application/json: + schema: + $ref: '#/components/schemas/ListOpenAIResponseObject' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Agents + description: List all OpenAI responses. + parameters: + - name: after + in: query + description: The ID of the last response to return. + required: false + schema: + type: string + - name: limit + in: query + description: The number of responses to return. + required: false + schema: + type: integer + - name: model + in: query + description: The model to filter responses by. + required: false + schema: + type: string + - name: order + in: query + description: >- + The order to sort responses by when sorted by created_at ('asc' or 'desc'). + required: false + schema: + $ref: '#/components/schemas/Order' post: responses: '200': @@ -963,7 +1010,7 @@ paths: required: true schema: type: string - /v1/openai/v1/responses/{id}: + /v1/openai/v1/responses/{response_id}: get: responses: '200': @@ -986,7 +1033,7 @@ paths: - Agents description: Retrieve an OpenAI response by its ID. parameters: - - name: id + - name: response_id in: path description: >- The ID of the OpenAI response to retrieve. @@ -2038,6 +2085,75 @@ paths: schema: $ref: '#/components/schemas/RegisterModelRequest' required: true + /v1/openai/v1/responses/{response_id}/input_items: + get: + responses: + '200': + description: An ListOpenAIResponseInputItem. + content: + application/json: + schema: + $ref: '#/components/schemas/ListOpenAIResponseInputItem' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Agents + description: >- + List input items for a given OpenAI response. + parameters: + - name: response_id + in: path + description: >- + The ID of the response to retrieve input items for. + required: true + schema: + type: string + - name: after + in: query + description: >- + An item ID to list items after, used for pagination. + required: false + schema: + type: string + - name: before + in: query + description: >- + An item ID to list items before, used for pagination. + required: false + schema: + type: string + - name: include + in: query + description: >- + Additional fields to include in the response. + required: false + schema: + type: array + items: + type: string + - name: limit + in: query + description: >- + A limit on the number of objects to be returned. Limit can range between + 1 and 100, and the default is 20. + required: false + schema: + type: integer + - name: order + in: query + description: >- + The order to return the input items in. Default is desc. 
+ required: false + schema: + $ref: '#/components/schemas/Order' /v1/providers: get: responses: @@ -4762,12 +4878,14 @@ components: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFileSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolFunction' + - $ref: '#/components/schemas/OpenAIResponseInputToolMCP' discriminator: propertyName: type mapping: web_search: '#/components/schemas/OpenAIResponseInputToolWebSearch' file_search: '#/components/schemas/OpenAIResponseInputToolFileSearch' function: '#/components/schemas/OpenAIResponseInputToolFunction' + mcp: '#/components/schemas/OpenAIResponseInputToolMCP' OpenAIResponseInputToolFileSearch: type: object properties: @@ -4822,6 +4940,66 @@ components: - type - name title: OpenAIResponseInputToolFunction + OpenAIResponseInputToolMCP: + type: object + properties: + type: + type: string + const: mcp + default: mcp + server_label: + type: string + server_url: + type: string + headers: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + require_approval: + oneOf: + - type: string + const: always + - type: string + const: never + - type: object + properties: + always: + type: array + items: + type: string + never: + type: array + items: + type: string + additionalProperties: false + title: ApprovalFilter + default: never + allowed_tools: + oneOf: + - type: array + items: + type: string + - type: object + properties: + tool_names: + type: array + items: + type: string + additionalProperties: false + title: AllowedToolsFilter + additionalProperties: false + required: + - type + - server_label + - server_url + - require_approval + title: OpenAIResponseInputToolMCP OpenAIResponseInputToolWebSearch: type: object properties: @@ -4897,12 +5075,12 @@ components: "OpenAIResponseOutputMessageFunctionToolCall": type: object properties: - arguments: - type: string call_id: type: string name: type: string + arguments: + type: string type: type: string const: function_call @@ -4913,12 +5091,10 @@ components: type: string additionalProperties: false required: - - arguments - call_id - name + - arguments - type - - id - - status title: >- OpenAIResponseOutputMessageFunctionToolCall "OpenAIResponseOutputMessageWebSearchToolCall": @@ -4952,6 +5128,8 @@ components: model: type: string description: The underlying LLM used for completions. 
+ instructions: + type: string previous_response_id: type: string description: >- @@ -5034,20 +5212,95 @@ components: - $ref: '#/components/schemas/OpenAIResponseMessage' - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' discriminator: propertyName: type mapping: message: '#/components/schemas/OpenAIResponseMessage' web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' + OpenAIResponseOutputMessageMCPCall: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_call + default: mcp_call + arguments: + type: string + name: + type: string + server_label: + type: string + error: + type: string + output: + type: string + additionalProperties: false + required: + - id + - type + - arguments + - name + - server_label + title: OpenAIResponseOutputMessageMCPCall + OpenAIResponseOutputMessageMCPListTools: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_list_tools + default: mcp_list_tools + server_label: + type: string + tools: + type: array + items: + type: object + properties: + input_schema: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + name: + type: string + description: + type: string + additionalProperties: false + required: + - input_schema + - name + title: MCPListToolsTool + additionalProperties: false + required: + - id + - type + - server_label + - tools + title: OpenAIResponseOutputMessageMCPListTools OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type mapping: response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' + response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -5079,6 +5332,33 @@ components: - type title: >- OpenAIResponseObjectStreamResponseCreated + "OpenAIResponseObjectStreamResponseOutputTextDelta": + type: object + properties: + content_index: + type: integer + delta: + type: string + item_id: + type: string + output_index: + type: integer + sequence_number: + type: integer + type: + type: string + const: response.output_text.delta + default: response.output_text.delta + additionalProperties: false + required: + - content_index + - delta + - item_id + - output_index + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseOutputTextDelta CreateUploadSessionRequest: type: object properties: @@ -6462,8 +6742,6 @@ components: default: tool toolgroup_id: type: string - tool_host: - $ref: '#/components/schemas/ToolHost' description: type: string parameters: @@ -6486,17 +6764,9 @@ components: - provider_id - type - toolgroup_id 
- - tool_host - description - parameters title: Tool - ToolHost: - type: string - enum: - - distribution - - client - - model_context_protocol - title: ToolHost ToolGroup: type: object properties: @@ -7042,6 +7312,96 @@ components: required: - data title: ListModelsResponse + ListOpenAIResponseInputItem: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseInput' + object: + type: string + const: list + default: list + additionalProperties: false + required: + - data + - object + title: ListOpenAIResponseInputItem + ListOpenAIResponseObject: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseObjectWithInput' + has_more: + type: boolean + first_id: + type: string + last_id: + type: string + object: + type: string + const: list + default: list + additionalProperties: false + required: + - data + - has_more + - first_id + - last_id + - object + title: ListOpenAIResponseObject + OpenAIResponseObjectWithInput: + type: object + properties: + created_at: + type: integer + error: + $ref: '#/components/schemas/OpenAIResponseError' + id: + type: string + model: + type: string + object: + type: string + const: response + default: response + output: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseOutput' + parallel_tool_calls: + type: boolean + default: false + previous_response_id: + type: string + status: + type: string + temperature: + type: number + top_p: + type: number + truncation: + type: string + user: + type: string + input: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseInput' + additionalProperties: false + required: + - created_at + - id + - model + - object + - output + - parallel_tool_calls + - status + - input + title: OpenAIResponseObjectWithInput ListProvidersResponse: type: object properties: @@ -8084,6 +8444,10 @@ components: placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + mode: + type: string + description: >- + Search mode for retrieval—either "vector" or "keyword". Default "vector". 
additionalProperties: false required: - query_generator_config diff --git a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb index 413b693d1..93f78d268 100644 --- a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb +++ b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb @@ -38,12 +38,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, "collapsed": true, - "id": "O9pGVlPIjpix", - "outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f" + "id": "O9pGVlPIjpix" }, "outputs": [], "source": [ @@ -55,12 +51,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, "collapsed": true, - "id": "JQpLUSNjlGAM", - "outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c" + "id": "JQpLUSNjlGAM" }, "outputs": [], "source": [ @@ -70,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -701,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "id": "TC_IwIAQo4q-" }, @@ -714,116 +706,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 305, - "referenced_widgets": [ - "feb82e061ee44283b4a46be858ef4cd7", - "78a2d2d4ee3f42f3be42ef4baa298561", - "ba5e6ca09f174ef3a348453cf5cfc24a", - "74b58e4647644c9daf9af488942fdaf4", - "d56e218958a041e286e80f24e400ab0b", - "cab80632b7564a9eb59583e09573c1ee", - "10c0d50d7c204de0b4c8e8f4d3ec0af5", - "626ef2f811ae4e119a0e85cebe92b91d", - "aef4172d916f40b0ab4ed09104e10f24", - "25529e7fd57049d2816d31f696eab1fd", - "093bdcb608cf4b4fa37b0032a3915187", - "c788d4e9e1e24dca9b6503689df9b631", - "d1587e2144bf46299c1bdec3ea96e4e7", - "500a072c09da41759cb2c942a16d8429", - "9785009392934e3bbb229e8781667cbc", - "84570fe2c2a54a068fb9b8cbc8b041a1", - "f9e579c58e3f4ae0bbb721dffa33bf0a", - "737116977f474ec0b68d88a40fd1086c", - "e6d6e516cd03452297d80c36376855dd", - "6ae0fadb3aeb4be18a9ab3279fb23145", - "fa4800a506ac480984d58933580df086", - "117468099dbc42fdaafc08207eaac7ab", - "44f585990aa244d8ba61f892dc1ccc1c", - "4fc59928a0544f95a4438b37d19ca437", - "fb644d47049f495397d0e60597c86ea3", - "78632694ff694442bc3fefc2cac2cbf5", - "083fd2549abd4b03bd41d8b92ec28f42", - "611d6472a58d419583acc416767a4c90", - "98c5ce434cff454eaaa3f0fd3498183a", - "3d0344a9cc744e369da1b6b7ea1b3be8", - "c452ccbf47a44073aee710175f707a7d", - "0218397c573e4b28bfb4ffa66464d50f", - "9b01bcd6e5174be2af19f457047017c8", - "4fed5720f30b4b3cbbc606a4f25e223b", - "6fa866b9971542739b0ed26d90ceac80", - "fe7553b513954cc68c427b5d9d260b33", - "4bc266d49a6741a88350e029d101425b", - "da57445f98e7427589962836c2b4287e", - "ad1fb86cc1f94fd9911eda03cf4a3783", - "fdefb51ad4c4418b98c5826126558011", - "179d41b80dc841e8a440482516b8bca5", - "22b1ecd2eff14770bcfb0c62d3d4213f", - "47f876cf41484d55b645e1e99337423a", - "340fbbb4982c460992c88885e79b47db", - "9659140487ca4d3ea799196d2c1ecf61", - "52150fd494d24eea89b5232077509355", - "04acde771d0a46699e1de07d9733d1a3", - "7b98103300814f3caea84266263b95a2", - "75f06408071c494f934bb909b84110d1", - "b09b2690894749339a9172e5ad0a9b75", - "cbed38801163438d891879b756f5baab", - "399a6417b23e4593bb244ec3abb6b46d", - "53a321f36b0d4e08a74a5bcfbd04434b", - "b8c0c8aaac0d4032bf5c673a43d084ab", - "d1f32499fa3f4795b92361637e23a9bb", - "c06f9a090fb54c74b947634bf6d11fa8", - 
"82991dcc80f14af9bd2e95f705980676", - "cd832e3842b945aabbb327856053f261", - "93ee645d54f34acdb0d15092d4a6f0d1", - "b77fe05bbcf84cdc8ef85b264ccd35f6", - "e17d286a965a49cfb8d5bf885865cb1e", - "ca015c1a0c1449e68edb282462435a3f", - "2932b06afde9468a976eb6bfb072b80e", - "d027c807ddc04f89bec41dc05fde7718", - "4ff3a6aaf706460bbba01b248b93000e", - "bfd75a39f0154c30adbaad1e2ca0f1e2", - "4f788a7920c346f3b42900825bd6711a", - "8e9358ec7d474808bb96c13e13489c67", - "f0dfeee2a8d64dedbc8ef55ad4e69932", - "9437b707bf1a4847a50aafeb4252dab5", - "f255707788704a76bd1651f26a22402d", - "3b70fa4e43ef4951862e119378c3c501", - "6c0a6a7fa8ca4e1c961a36305f0e7638", - "201bd914f9884e46b8e6df9d9900a6e8", - "f53b7ada01084e73bba6e14a95e2a534", - "d2029292327b488db02fd123ee2b75af", - "3e26bc24a3e44b4582f57913bdf98de4", - "9d2b6eabf7e14436b72bbf374b4a2a0a", - "b5d7cb5a6157449a850ef0e12e3d3eb7", - "c245d316bf9e44dabe5bfd1e47fc8d2e", - "963cf422ca894d82b0dd94c6165d41bf", - "78d0e2aa93674bbeb42bff87a23cce9b", - "12c6f1180eeb4e9eb9037ea5dd24ec8e", - "017a81d7160240a398947545963856f5", - "1cf8eeb8d81c4e8a8e95dd43296a78b9", - "5b0b5a3f79e94c51aae48fe0dd34ba0e", - "f5b34a743ce54fb591f25b04a2651d65", - "dec6399e2c5341aead66e1674d3e6c72", - "24e48376a72940679989a39a40bbe7f6", - "484df732051540859bc7ac9cecadc83c", - "4b33b1db50c34a2fa957d81a71a2a47f", - "e51d501e2f994baba40345ad632eabee", - "631a85e420b64e8cb6915af59c5ce08a", - "70af9cb2838c4a92bd67f8cb5c98d97f", - "158115266c284c4f8dbce3586151cbf1", - "ce5019b36cde44c58c5f596dbb59a2f8", - "b90d660ca8584ba1815a3c66b420c079", - "7c4d1de626784a59a7e0a33c24086186", - "21cf0e35ecd845a8b5e7c5ce241cf177" - ] - }, "collapsed": true, - "id": "DJkmoG2kq1_P", - "outputId": "8493ee59-c6ff-4bb6-d787-f295944db1cf" + "id": "DJkmoG2kq1_P" }, "outputs": [], "source": [ @@ -848,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -949,7 +835,7 @@ "\n", "client.benchmarks.register(\n", " benchmark_id=\"meta-reference::mmmu\",\n", - " # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the \n", + " # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the\n", " # `input_rows` argument and does not fetch data from the dataset.\n", " dataset_id=f\"mmmu-{subset}-{split}\",\n", " # Note: for the same reason as above, we can use any value as `scoring_functions`.\n", @@ -994,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "id": "HXmZf3Ymw-aX" }, @@ -1014,7 +900,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "Gc8azb4Rxr5J" }, @@ -1028,7 +914,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1168,7 +1054,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1293,7 +1179,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "id": "lxc9-eXYK5Av" + }, "outputs": [], "source": [] } @@ -1322,3088 +1210,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "017a81d7160240a398947545963856f5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", 
- "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "0218397c573e4b28bfb4ffa66464d50f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "04acde771d0a46699e1de07d9733d1a3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_399a6417b23e4593bb244ec3abb6b46d", - "max": 453677660, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_53a321f36b0d4e08a74a5bcfbd04434b", - "value": 453677660 - } - }, - "083fd2549abd4b03bd41d8b92ec28f42": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - 
"width": null - } - }, - "093bdcb608cf4b4fa37b0032a3915187": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "10c0d50d7c204de0b4c8e8f4d3ec0af5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "117468099dbc42fdaafc08207eaac7ab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "12c6f1180eeb4e9eb9037ea5dd24ec8e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "158115266c284c4f8dbce3586151cbf1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "179d41b80dc841e8a440482516b8bca5": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - 
"align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1cf8eeb8d81c4e8a8e95dd43296a78b9": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "201bd914f9884e46b8e6df9d9900a6e8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "21cf0e35ecd845a8b5e7c5ce241cf177": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - 
"_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "22b1ecd2eff14770bcfb0c62d3d4213f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "24e48376a72940679989a39a40bbe7f6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_484df732051540859bc7ac9cecadc83c", - "IPY_MODEL_4b33b1db50c34a2fa957d81a71a2a47f", - "IPY_MODEL_e51d501e2f994baba40345ad632eabee" - ], - "layout": "IPY_MODEL_631a85e420b64e8cb6915af59c5ce08a" - } - }, - "25529e7fd57049d2816d31f696eab1fd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2932b06afde9468a976eb6bfb072b80e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": 
null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "340fbbb4982c460992c88885e79b47db": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "399a6417b23e4593bb244ec3abb6b46d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3b70fa4e43ef4951862e119378c3c501": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3d0344a9cc744e369da1b6b7ea1b3be8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": 
"LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3e26bc24a3e44b4582f57913bdf98de4": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "44f585990aa244d8ba61f892dc1ccc1c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_4fc59928a0544f95a4438b37d19ca437", - "IPY_MODEL_fb644d47049f495397d0e60597c86ea3", - "IPY_MODEL_78632694ff694442bc3fefc2cac2cbf5" - ], - "layout": "IPY_MODEL_083fd2549abd4b03bd41d8b92ec28f42" - } - }, - "47f876cf41484d55b645e1e99337423a": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "484df732051540859bc7ac9cecadc83c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - 
"_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_70af9cb2838c4a92bd67f8cb5c98d97f", - "placeholder": "​", - "style": "IPY_MODEL_158115266c284c4f8dbce3586151cbf1", - "value": "Generating test split: 100%" - } - }, - "4b33b1db50c34a2fa957d81a71a2a47f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ce5019b36cde44c58c5f596dbb59a2f8", - "max": 287, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_b90d660ca8584ba1815a3c66b420c079", - "value": 287 - } - }, - "4bc266d49a6741a88350e029d101425b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_47f876cf41484d55b645e1e99337423a", - "placeholder": "​", - "style": "IPY_MODEL_340fbbb4982c460992c88885e79b47db", - "value": " 461M/461M [00:11<00:00, 31.2MB/s]" - } - }, - "4f788a7920c346f3b42900825bd6711a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_8e9358ec7d474808bb96c13e13489c67", - "IPY_MODEL_f0dfeee2a8d64dedbc8ef55ad4e69932", - "IPY_MODEL_9437b707bf1a4847a50aafeb4252dab5" - ], - "layout": "IPY_MODEL_f255707788704a76bd1651f26a22402d" - } - }, - "4fc59928a0544f95a4438b37d19ca437": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_611d6472a58d419583acc416767a4c90", - "placeholder": "​", - "style": "IPY_MODEL_98c5ce434cff454eaaa3f0fd3498183a", - "value": "validation-00000-of-00001.parquet: 100%" - } - }, - "4fed5720f30b4b3cbbc606a4f25e223b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - 
"_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_6fa866b9971542739b0ed26d90ceac80", - "IPY_MODEL_fe7553b513954cc68c427b5d9d260b33", - "IPY_MODEL_4bc266d49a6741a88350e029d101425b" - ], - "layout": "IPY_MODEL_da57445f98e7427589962836c2b4287e" - } - }, - "4ff3a6aaf706460bbba01b248b93000e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "500a072c09da41759cb2c942a16d8429": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e6d6e516cd03452297d80c36376855dd", - "max": 29453850, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_6ae0fadb3aeb4be18a9ab3279fb23145", - "value": 29453850 - } - }, - "52150fd494d24eea89b5232077509355": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b09b2690894749339a9172e5ad0a9b75", - "placeholder": "​", - "style": "IPY_MODEL_cbed38801163438d891879b756f5baab", - "value": "test-00001-of-00003.parquet: 100%" - } - }, - "53a321f36b0d4e08a74a5bcfbd04434b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - 
"5b0b5a3f79e94c51aae48fe0dd34ba0e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "611d6472a58d419583acc416767a4c90": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "626ef2f811ae4e119a0e85cebe92b91d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "631a85e420b64e8cb6915af59c5ce08a": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": 
null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6ae0fadb3aeb4be18a9ab3279fb23145": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "6c0a6a7fa8ca4e1c961a36305f0e7638": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6fa866b9971542739b0ed26d90ceac80": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ad1fb86cc1f94fd9911eda03cf4a3783", - "placeholder": "​", - "style": "IPY_MODEL_fdefb51ad4c4418b98c5826126558011", - "value": "test-00000-of-00003.parquet: 100%" - } - }, - "70af9cb2838c4a92bd67f8cb5c98d97f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - 
"737116977f474ec0b68d88a40fd1086c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "74b58e4647644c9daf9af488942fdaf4": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_25529e7fd57049d2816d31f696eab1fd", - "placeholder": "​", - "style": "IPY_MODEL_093bdcb608cf4b4fa37b0032a3915187", - "value": " 36.0k/36.0k [00:00<00:00, 1.29MB/s]" - } - }, - "75f06408071c494f934bb909b84110d1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "78632694ff694442bc3fefc2cac2cbf5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_0218397c573e4b28bfb4ffa66464d50f", - "placeholder": "​", - "style": "IPY_MODEL_9b01bcd6e5174be2af19f457047017c8", - "value": " 165M/165M [00:03<00:00, 42.9MB/s]" - } - }, - "78a2d2d4ee3f42f3be42ef4baa298561": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, 
- "layout": "IPY_MODEL_cab80632b7564a9eb59583e09573c1ee", - "placeholder": "​", - "style": "IPY_MODEL_10c0d50d7c204de0b4c8e8f4d3ec0af5", - "value": "README.md: 100%" - } - }, - "78d0e2aa93674bbeb42bff87a23cce9b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7b98103300814f3caea84266263b95a2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b8c0c8aaac0d4032bf5c673a43d084ab", - "placeholder": "​", - "style": "IPY_MODEL_d1f32499fa3f4795b92361637e23a9bb", - "value": " 454M/454M [00:11<00:00, 40.4MB/s]" - } - }, - "7c4d1de626784a59a7e0a33c24086186": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "82991dcc80f14af9bd2e95f705980676": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": 
"@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e17d286a965a49cfb8d5bf885865cb1e", - "placeholder": "​", - "style": "IPY_MODEL_ca015c1a0c1449e68edb282462435a3f", - "value": "test-00002-of-00003.parquet: 100%" - } - }, - "84570fe2c2a54a068fb9b8cbc8b041a1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "8e9358ec7d474808bb96c13e13489c67": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3b70fa4e43ef4951862e119378c3c501", - "placeholder": "​", - "style": "IPY_MODEL_6c0a6a7fa8ca4e1c961a36305f0e7638", - "value": "Generating dev split: 100%" - } - }, - "93ee645d54f34acdb0d15092d4a6f0d1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_4ff3a6aaf706460bbba01b248b93000e", - "placeholder": "​", - "style": "IPY_MODEL_bfd75a39f0154c30adbaad1e2ca0f1e2", - "value": " 471M/471M [00:11<00:00, 41.5MB/s]" - } - }, - "9437b707bf1a4847a50aafeb4252dab5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - 
"layout": "IPY_MODEL_d2029292327b488db02fd123ee2b75af", - "placeholder": "​", - "style": "IPY_MODEL_3e26bc24a3e44b4582f57913bdf98de4", - "value": " 5/5 [00:00<00:00,  8.03 examples/s]" - } - }, - "963cf422ca894d82b0dd94c6165d41bf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_f5b34a743ce54fb591f25b04a2651d65", - "placeholder": "​", - "style": "IPY_MODEL_dec6399e2c5341aead66e1674d3e6c72", - "value": " 30/30 [00:03<00:00,  8.23 examples/s]" - } - }, - "9659140487ca4d3ea799196d2c1ecf61": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_52150fd494d24eea89b5232077509355", - "IPY_MODEL_04acde771d0a46699e1de07d9733d1a3", - "IPY_MODEL_7b98103300814f3caea84266263b95a2" - ], - "layout": "IPY_MODEL_75f06408071c494f934bb909b84110d1" - } - }, - "9785009392934e3bbb229e8781667cbc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fa4800a506ac480984d58933580df086", - "placeholder": "​", - "style": "IPY_MODEL_117468099dbc42fdaafc08207eaac7ab", - "value": " 29.5M/29.5M [00:00<00:00, 36.5MB/s]" - } - }, - "98c5ce434cff454eaaa3f0fd3498183a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "9b01bcd6e5174be2af19f457047017c8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "9d2b6eabf7e14436b72bbf374b4a2a0a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - 
"_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_b5d7cb5a6157449a850ef0e12e3d3eb7", - "IPY_MODEL_c245d316bf9e44dabe5bfd1e47fc8d2e", - "IPY_MODEL_963cf422ca894d82b0dd94c6165d41bf" - ], - "layout": "IPY_MODEL_78d0e2aa93674bbeb42bff87a23cce9b" - } - }, - "ad1fb86cc1f94fd9911eda03cf4a3783": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "aef4172d916f40b0ab4ed09104e10f24": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "b09b2690894749339a9172e5ad0a9b75": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b5d7cb5a6157449a850ef0e12e3d3eb7": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - 
"_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_12c6f1180eeb4e9eb9037ea5dd24ec8e", - "placeholder": "​", - "style": "IPY_MODEL_017a81d7160240a398947545963856f5", - "value": "Generating validation split: 100%" - } - }, - "b77fe05bbcf84cdc8ef85b264ccd35f6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b8c0c8aaac0d4032bf5c673a43d084ab": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b90d660ca8584ba1815a3c66b420c079": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "ba5e6ca09f174ef3a348453cf5cfc24a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - 
"_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_626ef2f811ae4e119a0e85cebe92b91d", - "max": 36030, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_aef4172d916f40b0ab4ed09104e10f24", - "value": 36030 - } - }, - "bfd75a39f0154c30adbaad1e2ca0f1e2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "c06f9a090fb54c74b947634bf6d11fa8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_82991dcc80f14af9bd2e95f705980676", - "IPY_MODEL_cd832e3842b945aabbb327856053f261", - "IPY_MODEL_93ee645d54f34acdb0d15092d4a6f0d1" - ], - "layout": "IPY_MODEL_b77fe05bbcf84cdc8ef85b264ccd35f6" - } - }, - "c245d316bf9e44dabe5bfd1e47fc8d2e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1cf8eeb8d81c4e8a8e95dd43296a78b9", - "max": 30, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_5b0b5a3f79e94c51aae48fe0dd34ba0e", - "value": 30 - } - }, - "c452ccbf47a44073aee710175f707a7d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "c788d4e9e1e24dca9b6503689df9b631": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_d1587e2144bf46299c1bdec3ea96e4e7", - "IPY_MODEL_500a072c09da41759cb2c942a16d8429", - "IPY_MODEL_9785009392934e3bbb229e8781667cbc" - ], - "layout": 
"IPY_MODEL_84570fe2c2a54a068fb9b8cbc8b041a1" - } - }, - "ca015c1a0c1449e68edb282462435a3f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "cab80632b7564a9eb59583e09573c1ee": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "cbed38801163438d891879b756f5baab": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "cd832e3842b945aabbb327856053f261": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2932b06afde9468a976eb6bfb072b80e", - "max": 470745176, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d027c807ddc04f89bec41dc05fde7718", - "value": 470745176 - } - }, - "ce5019b36cde44c58c5f596dbb59a2f8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": 
null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d027c807ddc04f89bec41dc05fde7718": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d1587e2144bf46299c1bdec3ea96e4e7": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_f9e579c58e3f4ae0bbb721dffa33bf0a", - "placeholder": "​", - "style": "IPY_MODEL_737116977f474ec0b68d88a40fd1086c", - "value": "dev-00000-of-00001.parquet: 100%" - } - }, - "d1f32499fa3f4795b92361637e23a9bb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "d2029292327b488db02fd123ee2b75af": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - 
"d56e218958a041e286e80f24e400ab0b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "da57445f98e7427589962836c2b4287e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "dec6399e2c5341aead66e1674d3e6c72": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e17d286a965a49cfb8d5bf885865cb1e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - 
"grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e51d501e2f994baba40345ad632eabee": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7c4d1de626784a59a7e0a33c24086186", - "placeholder": "​", - "style": "IPY_MODEL_21cf0e35ecd845a8b5e7c5ce241cf177", - "value": " 287/287 [00:23<00:00, 12.48 examples/s]" - } - }, - "e6d6e516cd03452297d80c36376855dd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f0dfeee2a8d64dedbc8ef55ad4e69932": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_201bd914f9884e46b8e6df9d9900a6e8", - "max": 5, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f53b7ada01084e73bba6e14a95e2a534", - "value": 5 - } - }, - "f255707788704a76bd1651f26a22402d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - 
"_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f53b7ada01084e73bba6e14a95e2a534": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "f5b34a743ce54fb591f25b04a2651d65": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f9e579c58e3f4ae0bbb721dffa33bf0a": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - 
"height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "fa4800a506ac480984d58933580df086": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "fb644d47049f495397d0e60597c86ea3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3d0344a9cc744e369da1b6b7ea1b3be8", - "max": 165333397, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_c452ccbf47a44073aee710175f707a7d", - "value": 165333397 - } - }, - "fdefb51ad4c4418b98c5826126558011": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "fe7553b513954cc68c427b5d9d260b33": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_179d41b80dc841e8a440482516b8bca5", - "max": 461411018, - "min": 0, - "orientation": "horizontal", - "style": 
"IPY_MODEL_22b1ecd2eff14770bcfb0c62d3d4213f", - "value": 461411018 - } - }, - "feb82e061ee44283b4a46be858ef4cd7": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_78a2d2d4ee3f42f3be42ef4baa298561", - "IPY_MODEL_ba5e6ca09f174ef3a348453cf5cfc24a", - "IPY_MODEL_74b58e4647644c9daf9af488942fdaf4" - ], - "layout": "IPY_MODEL_d56e218958a041e286e80f24e400ab0b" - } - } - } } }, "nbformat": 4, diff --git a/docs/readme.md b/docs/readme.md index b88a4738d..d84dbe6eb 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -3,10 +3,10 @@ Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html). ## Render locally + +From the llama-stack root directory, run the following command to render the docs locally: ```bash -pip install -r requirements.txt -cd docs -python -m sphinx_autobuild source _build +uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all ``` You can open up the docs in your browser at http://localhost:8000 diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 6cd45c33b..000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ -linkify -myst-parser --e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme -sphinx==8.1.3 -sphinx-copybutton -sphinx-design -sphinx-pdj-theme -sphinx-rtd-theme>=1.0.0 -sphinx-tabs -sphinx_autobuild -sphinx_rtd_dark_mode -sphinxcontrib-mermaid -sphinxcontrib-openapi -sphinxcontrib-redoc -sphinxcontrib-video -tomli diff --git a/docs/source/conf.py b/docs/source/conf.py index 501a923dd..6e59dbdfb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,11 @@ from docutils import nodes # Read version from pyproject.toml with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f: pypi_url = "https://pypi.org/pypi/llama-stack/json" - version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"] + headers = { + 'User-Agent': 'pip/23.0.1 (python 3.11)', # Mimic pip's user agent + 'Accept': 'application/json' + } + version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"] print(f"{version_tag=}") # generate the full link including text and url here @@ -53,14 +57,6 @@ myst_enable_extensions = ["colon_fence"] html_theme = "sphinx_rtd_theme" html_use_relative_paths = True - -# html_theme = "sphinx_pdj_theme" -# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()] - -# html_theme = "pytorch_sphinx_theme" -# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] - - templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index d9b73c910..0dbabf8aa 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -338,6 +338,48 @@ INFO: Application startup complete. 
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit) INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK ``` +### Listing Distributions +Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files. + +``` +llama stack list -h +usage: llama stack list [-h] + +list the build stacks + +options: + -h, --help show this help message and exit +``` + +Example Usage + +``` +llama stack list +``` + +### Removing a Distribution +Use the remove command to delete a distribution you've previously built. + +``` +llama stack rm -h +usage: llama stack rm [-h] [--all] [name] + +Remove the build stack + +positional arguments: + name Name of the stack to delete (default: None) + +options: + -h, --help show this help message and exit + --all, -a Delete all stacks (use with caution) (default: False) +``` + +Example +``` +llama stack rm llamastack-test +``` + +To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they’re no longer needed. ### Troubleshooting diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index b62227a84..de99b6576 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -118,11 +118,6 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS - auth: # Optional: Authentication configuration - provider_type: "kubernetes" # Type of auth provider - config: # Provider-specific configuration - api_server_url: "https://kubernetes.default.svc" - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate ``` ### Authentication Configuration @@ -135,7 +130,7 @@ Authorization: Bearer The server supports multiple authentication providers: -#### Kubernetes Provider +#### OAuth 2.0/OpenID Connect Provider with Kubernetes The Kubernetes cluster must be configured to use a service account for authentication. @@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token ``` -Validates tokens against the Kubernetes API server: +Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests +and that a role binding exists that allows the anonymous user to read the cluster's OpenID JWKS +endpoint.
If that is not the case, you can create a `ClusterRole` and `ClusterRoleBinding` that grant the +anonymous user read access to that endpoint: + +```yaml +# allow-anonymous-openid.yaml +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: allow-anonymous-openid +rules: +- nonResourceURLs: ["/openid/v1/jwks"] + verbs: ["get"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: allow-anonymous-openid +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: allow-anonymous-openid +subjects: +- kind: User + name: system:anonymous + apiGroup: rbac.authorization.k8s.io +``` + +And then apply the configuration: +```bash +kubectl apply -f allow-anonymous-openid.yaml +``` + +This configuration validates tokens against the Kubernetes API server through the OIDC provider: ```yaml server: auth: - provider_type: "kubernetes" + provider_type: "oauth2_token" config: - api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server - ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate + jwks: + uri: "https://kubernetes.default.svc" + key_recheck_period: 3600 + tls_cafile: "/path/to/ca.crt" + issuer: "https://kubernetes.default.svc" + audience: "https://kubernetes.default.svc" +``` + +To find your cluster's audience, run: +```bash +kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud +``` + +For the issuer, you can use the OIDC provider's URL: +```bash +kubectl get --raw /.well-known/openid-configuration | jq .issuer +``` + +For `tls_cafile`, use the cluster's CA certificate; the following prints the path referenced by your kubeconfig: +```bash +kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' ``` The provider extracts user information from the JWT token: @@ -208,6 +256,80 @@ And must respond with: If no access attributes are returned, the token is used as a namespace. +### Quota Configuration + +The `quota` section allows you to enable server-side request throttling for both +authenticated and anonymous clients. This is useful for preventing abuse, enforcing +fairness across tenants, and controlling infrastructure costs without requiring +client-side rate limiting or external proxies. + +Quotas are disabled by default. When enabled, each client is tracked using either: + +* Their authenticated `client_id` (derived from the Bearer token), or +* Their IP address (fallback for anonymous requests) + +Quota state is stored in a SQLite-backed key-value store, and rate limits are applied +within a configurable time window (currently only `day` is supported). + +#### Example + +```yaml +server: + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +#### Configuration Options + +| Field | Description | +| ---------------------------- | -------------------------------------------------------------------------- | +| `kvstore` | Required. Backend storage config for tracking request counts. | +| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. | +| `kvstore.db_path` | File path to the SQLite database. | +| `anonymous_max_requests` | Max requests per period for unauthenticated clients. | +| `authenticated_max_requests` | Max requests per period for authenticated clients. | +| `period` | Time window for quota enforcement. Only `"day"` is supported. 
| + +> Note: if `authenticated_max_requests` is set but no authentication provider is +configured, the server will fall back to applying `anonymous_max_requests` to all +clients. + +#### Example with Authentication Enabled + +```yaml +server: + port: 8321 + auth: + provider_type: custom + config: + endpoint: https://auth.example.com/validate + quota: + kvstore: + type: sqlite + db_path: ./quotas.db + anonymous_max_requests: 100 + authenticated_max_requests: 1000 + period: day +``` + +If a client exceeds their limit, the server responds with: + +```http +HTTP/1.1 429 Too Many Requests +Content-Type: application/json + +{ + "error": { + "message": "Quota exceeded" + } +} +``` + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md index aaa8fd3cc..bb4842362 100644 --- a/docs/source/distributions/self_hosted_distro/sambanova.md +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p |-----|-------------| | agents | `inline::meta-reference` | | inference | `remote::sambanova`, `inline::sentence-transformers` | -| safety | `inline::llama-guard` | +| safety | `remote::sambanova` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | @@ -48,33 +48,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. -### Via Docker -This method allows you to get started quickly without having to build the distribution code. +### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-sambanova \ + -v ~/.llama:/root/.llama \ + distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md index 43d10c751..49ba659f7 100644 --- a/docs/source/providers/vector_io/sqlite-vec.md +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps: 2. 
Configure your Llama Stack project to use SQLite-Vec. 3. Start storing and querying vectors. +## Supported Search Modes + +The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. + +When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in +`RAGQueryConfig`. For example: + +```python +from llama_stack.apis.tools.rag_tool import RAGQueryConfig + +query_config = RAGQueryConfig(max_chunks=6, mode="vector") + +results = client.tool_runtime.rag_tool.query( + vector_db_ids=[vector_db_id], + content="what is torchtune", + query_config=query_config, +) +``` + ## Installation You can install SQLite-Vec using pip: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index b2f85336c..b79c512b8 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.common.responses import Order, PaginatedResponse from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, @@ -31,6 +31,8 @@ from llama_stack.apis.tools import ToolDef from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from .openai_responses import ( + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -579,14 +581,14 @@ class Agents(Protocol): # # Both of these APIs are inherently stateful. - @webmethod(route="/openai/v1/responses/{id}", method="GET") + @webmethod(route="/openai/v1/responses/{response_id}", method="GET") async def get_openai_response( self, - id: str, + response_id: str, ) -> OpenAIResponseObject: """Retrieve an OpenAI response by its ID. - :param id: The ID of the OpenAI response to retrieve. + :param response_id: The ID of the OpenAI response to retrieve. :returns: An OpenAIResponseObject. """ ... @@ -596,6 +598,7 @@ self, input: str | list[OpenAIResponseInput], model: str, + instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -610,3 +613,43 @@ :returns: An OpenAIResponseObject. """ ... + + @webmethod(route="/openai/v1/responses", method="GET") + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """List all OpenAI responses. + + :param after: The ID of the last response from the previous page, used for pagination. + :param limit: The number of responses to return. + :param model: The model to filter responses by. + :param order: The sort order for responses by created_at timestamp ('asc' or 'desc'). + :returns: A ListOpenAIResponseObject. + """ + ... + + @webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET") + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. + + :param response_id: The ID of the response to retrieve input items for.
+ :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. + :param order: The order to return the input items in. Default is desc. + :returns: An ListOpenAIResponseInputItem. + """ + ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index dcf0c7f9c..6806e1d3f 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -10,6 +10,9 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type, register_schema +# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably +# take their YAML and generate this file automatically. Their YAML is available. + @json_schema_type class OpenAIResponseError(BaseModel): @@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): - arguments: str call_id: str name: str + arguments: str type: Literal["function_call"] = "function_call" + id: str | None = None + status: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPCall(BaseModel): id: str - status: str + type: Literal["mcp_call"] = "mcp_call" + arguments: str + name: str + server_label: str + error: str | None = None + output: str | None = None + + +class MCPListToolsTool(BaseModel): + input_schema: dict[str, Any] + name: str + description: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPListTools(BaseModel): + id: str + type: Literal["mcp_list_tools"] = "mcp_list_tools" + server_label: str + tools: list[MCPListToolsTool] OpenAIResponseOutput = Annotated[ - OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseMessage + | OpenAIResponseOutputMessageWebSearchToolCall + | OpenAIResponseOutputMessageFunctionToolCall + | OpenAIResponseOutputMessageMCPCall + | OpenAIResponseOutputMessageMCPListTools, Field(discriminator="type"), ] register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") @@ -117,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel): type: Literal["response.created"] = "response.created" +@json_schema_type +class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel): + content_index: int + delta: str + item_id: str + output_index: int + sequence_number: int + type: Literal["response.output_text.delta"] = "response.output_text.delta" + + @json_schema_type class OpenAIResponseObjectStreamResponseCompleted(BaseModel): response: OpenAIResponseObject @@ -124,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel): OpenAIResponseObjectStream = Annotated[ - OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseCreated + | OpenAIResponseObjectStreamResponseOutputTextDelta + | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream") @@ -186,13 +230,50 @@ class OpenAIResponseInputToolFileSearch(BaseModel): # TODO: add filters +class ApprovalFilter(BaseModel): + always: list[str] | None = None + never: list[str] 
| None = None + + +class AllowedToolsFilter(BaseModel): + tool_names: list[str] | None = None + + +@json_schema_type +class OpenAIResponseInputToolMCP(BaseModel): + type: Literal["mcp"] = "mcp" + server_label: str + server_url: str + headers: dict[str, Any] | None = None + + require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never" + allowed_tools: list[str] | AllowedToolsFilter | None = None + + OpenAIResponseInputTool = Annotated[ - OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction, + OpenAIResponseInputToolWebSearch + | OpenAIResponseInputToolFileSearch + | OpenAIResponseInputToolFunction + | OpenAIResponseInputToolMCP, Field(discriminator="type"), ] register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") -class OpenAIResponseInputItemList(BaseModel): +class ListOpenAIResponseInputItem(BaseModel): data: list[OpenAIResponseInput] object: Literal["list"] = "list" + + +@json_schema_type +class OpenAIResponseObjectWithInput(OpenAIResponseObject): + input: list[OpenAIResponseInput] + + +@json_schema_type +class ListOpenAIResponseObject(BaseModel): + data: list[OpenAIResponseObjectWithInput] + has_more: bool + first_id: str + last_id: str + object: Literal["list"] = "list" diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py deleted file mode 100644 index 4d01d7ad1..000000000 --- a/llama_stack/apis/common/deployment_types.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import Any - -from pydantic import BaseModel - -from llama_stack.apis.common.content_types import URL -from llama_stack.schema_utils import json_schema_type - - -@json_schema_type -class RestAPIMethod(Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - - -@json_schema_type -class RestAPIExecutionConfig(BaseModel): - url: URL - method: RestAPIMethod - params: dict[str, Any] | None = None - headers: dict[str, Any] | None = None - body: dict[str, Any] | None = None diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py index b3bb5cb6b..5cb41e23d 100644 --- a/llama_stack/apis/common/responses.py +++ b/llama_stack/apis/common/responses.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any from pydantic import BaseModel @@ -11,6 +12,11 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type +class Order(Enum): + asc = "asc" + desc = "desc" + + @json_schema_type class PaginatedResponse(BaseModel): """A generic paginated response that follows a simple format. 
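
For reference, the MCP tool input models added to `openai_responses.py` above are plain Pydantic schemas. Below is a minimal, illustrative sketch (not part of this change) of how a caller might construct one; the server label, URL, header value, and tool names are invented for the example:

```python
from llama_stack.apis.agents.openai_responses import (
    ApprovalFilter,
    OpenAIResponseInputToolMCP,
)

# All concrete values here are hypothetical and for illustration only.
mcp_tool = OpenAIResponseInputToolMCP(
    server_label="docs",                          # label echoed back in mcp_list_tools / mcp_call outputs
    server_url="http://localhost:3000/sse",       # assumed MCP server endpoint
    headers={"Authorization": "Bearer <token>"},  # optional headers forwarded to the server
    # Structured approval policy: always ask before write_file, never for search.
    require_approval=ApprovalFilter(always=["write_file"], never=["search"]),
    # Restrict which of the server's tools the model may call.
    allowed_tools=["search", "write_file"],
)

print(mcp_tool.model_dump_json(indent=2))
```

`require_approval` also accepts the bare literals `"always"` or `"never"` (the default is `"never"`), and `allowed_tools` can alternatively be an `AllowedToolsFilter`.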
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7f8b20952..e79dc6d94 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( @@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel): object: Literal["list"] = "list" -class Order(Enum): - asc = "asc" - desc = "desc" - - @runtime_checkable @trace_protocol class InferenceProvider(Protocol): diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index de3e4c62c..1e3542f74 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel): :param chunk_template: Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" + :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector". """ # This config defines how a query is generated using the messages @@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + mode: str | None = None @field_validator("chunk_template") def validate_chunk_template(cls, v: str) -> str: diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 2f62b0ba1..0c8d47edf 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -27,18 +27,10 @@ class ToolParameter(BaseModel): default: Any | None = None -@json_schema_type -class ToolHost(Enum): - distribution = "distribution" - client = "client" - model_context_protocol = "model_context_protocol" - - @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str - tool_host: ToolHost description: str parameters: list[ToolParameter] metadata: dict[str, Any] | None = None @@ -76,8 +68,8 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): - def get_tool(self, tool_name: str) -> Tool: ... - def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... + async def get_tool(self, tool_name: str) -> Tool: ... + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... 
class ListToolGroupsResponse(BaseModel): diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 09c753776..b96842119 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -9,6 +9,7 @@ import asyncio import json import os import shutil +import sys from dataclasses import dataclass from datetime import datetime, timezone from functools import partial @@ -377,14 +378,15 @@ def _meta_download( downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) asyncio.run(downloader.download_all(tasks)) - cprint(f"\nSuccessfully downloaded model to {output_dir}", "green") + cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr) cprint( f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}", - "white", + file=sys.stderr, ) cprint( f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}", - "yellow", + color="yellow", + file=sys.stderr, ) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 4c6f82d1b..f6f72946a 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -79,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates", color="red", + file=sys.stderr, ) sys.exit(1) build_config = available_templates[args.template] @@ -88,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}", color="red", + file=sys.stderr, ) sys.exit(1) elif args.providers: @@ -97,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( "Could not parse `--providers`. 
Please ensure the list is in the format api1=provider1,api2=provider2", color="red", + file=sys.stderr, ) sys.exit(1) api, provider = api_provider.split("=") @@ -105,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"{api} is not a valid API.", color="red", + file=sys.stderr, ) sys.exit(1) if provider in providers_for_api: @@ -113,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"{provider} is not a valid provider for the {api} API.", color="red", + file=sys.stderr, ) sys.exit(1) distribution_spec = DistributionSpec( @@ -123,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Please specify a image-type (container | conda | venv) for {args.template}", color="red", + file=sys.stderr, ) sys.exit(1) @@ -151,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`", color="yellow", + file=sys.stderr, ) image_name = f"llamastack-{name}" else: cprint( f"Using conda environment {image_name}", color="green", + file=sys.stderr, ) else: image_name = f"llamastack-{name}" @@ -169,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None: """, ), color="green", + file=sys.stderr, ) - print("Tip: use to see options for the providers.\n") + cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) providers = dict() for api, providers_for_api in get_provider_registry().items(): @@ -213,6 +222,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Could not parse config file {args.config}: {e}", color="red", + file=sys.stderr, ) sys.exit(1) @@ -239,22 +249,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint( f"Error building stack: {exc}", color="red", + file=sys.stderr, ) - cprint("Stack trace:", color="red") + cprint("Stack trace:", color="red", file=sys.stderr) traceback.print_exc() sys.exit(1) + if run_config is None: cprint( "Run config path is empty", color="red", + file=sys.stderr, ) sys.exit(1) if args.run: config_dict = yaml.safe_load(run_config.read_text()) config = parse_and_maybe_upgrade_config(config_dict) - if not os.path.exists(config.external_providers_dir): - os.makedirs(config.external_providers_dir, exist_ok=True) + if config.external_providers_dir and not config.external_providers_dir.exists(): + config.external_providers_dir.mkdir(exist_ok=True) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) run_command(run_args) @@ -304,6 +317,7 @@ def _generate_run_config( cprint( f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", color="yellow", + file=sys.stderr, ) # Set config_type to None to avoid UnboundLocalError config_type = None @@ -331,10 +345,7 @@ def _generate_run_config( # For non-container builds, the run.yaml is generated at the very end of the build process so it # makes sense to display this message if build_config.image_type != LlamaStackImageType.CONTAINER.value: - cprint( - f"You can now run your stack with `llama stack run {run_config_file}`", - color="green", - ) + cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr) return run_config_file @@ -372,7 +383,7 @@ def _run_stack_build_command_from_build_config( # 
Generate the run.yaml so it can be included in the container image with the proper entrypoint # Only do this if we're building a container image and we're not using a template if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: - cprint("Generating run.yaml file", color="green") + cprint("Generating run.yaml file", color="yellow", file=sys.stderr) run_config_file = _generate_run_config(build_config, build_dir, image_name) with open(build_file_path, "w") as f: @@ -396,11 +407,13 @@ def _run_stack_build_command_from_build_config( run_config_file = build_dir / f"{template_name}-run.yaml" shutil.copy(path, run_config_file) - cprint("Build Successful!", color="green") - cprint("You can find the newly-built template here: " + colored(template_path, "light_blue")) + cprint("Build Successful!", color="green", file=sys.stderr) + cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr) cprint( "You can run the new Llama Stack distro via: " - + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue") + + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"), + color="green", + file=sys.stderr, ) return template_path else: diff --git a/llama_stack/cli/stack/list_stacks.py b/llama_stack/cli/stack/list_stacks.py new file mode 100644 index 000000000..2ea0fdeea --- /dev/null +++ b/llama_stack/cli/stack/list_stacks.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +from pathlib import Path + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class StackListBuilds(Subcommand): + """List built stacks in .llama/distributions directory""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama stack list", + description="list the build stacks", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._list_stack_command) + + def _get_distribution_dirs(self) -> dict[str, Path]: + """Return a dictionary of distribution names and their paths""" + distributions = {} + dist_dir = Path.home() / ".llama" / "distributions" + + if dist_dir.exists(): + for stack_dir in dist_dir.iterdir(): + if stack_dir.is_dir(): + distributions[stack_dir.name] = stack_dir + return distributions + + def _list_stack_command(self, args: argparse.Namespace) -> None: + distributions = self._get_distribution_dirs() + + if not distributions: + print("No stacks found in ~/.llama/distributions") + return + + headers = ["Stack Name", "Path"] + headers.extend(["Build Config", "Run Config"]) + rows = [] + for name, path in distributions.items(): + row = [name, str(path)] + # Check for build and run config files + build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No" + run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No" + row.extend([build_config, run_config]) + rows.append(row) + print_table(rows, headers, separate_rows=True) diff --git a/llama_stack/cli/stack/remove.py b/llama_stack/cli/stack/remove.py new file mode 100644 index 000000000..a1796941e --- /dev/null +++ b/llama_stack/cli/stack/remove.py @@ -0,0 +1,115 @@ +# Copyright 
(c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import shutil +import sys +from pathlib import Path + +from termcolor import cprint + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class StackRemove(Subcommand): + """Remove the build stack""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "rm", + prog="llama stack rm", + description="Remove the build stack", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._remove_stack_build_command) + + def _add_arguments(self) -> None: + self.parser.add_argument( + "name", + type=str, + nargs="?", + help="Name of the stack to delete", + ) + self.parser.add_argument( + "--all", + "-a", + action="store_true", + help="Delete all stacks (use with caution)", + ) + + def _get_distribution_dirs(self) -> dict[str, Path]: + """Return a dictionary of distribution names and their paths""" + distributions = {} + dist_dir = Path.home() / ".llama" / "distributions" + + if dist_dir.exists(): + for stack_dir in dist_dir.iterdir(): + if stack_dir.is_dir(): + distributions[stack_dir.name] = stack_dir + return distributions + + def _list_stacks(self) -> None: + """Display available stacks in a table""" + distributions = self._get_distribution_dirs() + if not distributions: + cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr) + sys.exit(1) + + headers = ["Stack Name", "Path"] + rows = [[name, str(path)] for name, path in distributions.items()] + print_table(rows, headers, separate_rows=True) + + def _remove_stack_build_command(self, args: argparse.Namespace) -> None: + distributions = self._get_distribution_dirs() + + if args.all: + confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower() + if confirm != "yes-i-really-want": + cprint("Deletion cancelled.", color="green", file=sys.stderr) + return + + for name, path in distributions.items(): + try: + shutil.rmtree(path) + cprint(f"Deleted stack: {name}", color="green", file=sys.stderr) + except Exception as e: + cprint( + f"Failed to delete stack {name}: {e}", + color="red", + file=sys.stderr, + ) + sys.exit(1) + + if not args.name: + self._list_stacks() + if not args.name: + return + + if args.name not in distributions: + self._list_stacks() + cprint( + f"Stack not found: {args.name}", + color="red", + file=sys.stderr, + ) + sys.exit(1) + + stack_path = distributions[args.name] + + confirm = input(f"Are you sure you want to delete stack '{args.name}'? 
[y/N] ").lower() + if confirm != "y": + cprint("Deletion cancelled.", color="green", file=sys.stderr) + return + + try: + shutil.rmtree(stack_path) + cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr) + except Exception as e: + cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr) + sys.exit(1) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 4a44e0366..27745edac 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -6,6 +6,7 @@ import argparse import os +import subprocess from pathlib import Path from llama_stack.cli.stack.utils import ImageType @@ -60,6 +61,11 @@ class StackRun(Subcommand): help="Image Type used during the build. This can be either conda or container or venv.", choices=[e.value for e in ImageType], ) + self.parser.add_argument( + "--enable-ui", + action="store_true", + help="Start the UI server", + ) # If neither image type nor image name is provided, but at the same time # the current environment has conda breadcrumbs, then assume what the user @@ -83,6 +89,8 @@ class StackRun(Subcommand): from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.exec import formulate_run_args, run_command + if args.enable_ui: + self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) # Check if config is required based on image type @@ -170,3 +178,44 @@ class StackRun(Subcommand): run_args.extend(["--env", f"{key}={value}"]) run_command(run_args) + + def _start_ui_development_server(self, stack_server_port: int): + logger.info("Attempting to start UI development server...") + # Check if npm is available + npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False) + if npm_check.returncode != 0: + logger.warning( + f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}" + ) + return + + ui_dir = REPO_ROOT / "llama_stack" / "ui" + logs_dir = Path("~/.llama/ui/logs").expanduser() + try: + # Create logs directory if it doesn't exist + logs_dir.mkdir(parents=True, exist_ok=True) + + ui_stdout_log_path = logs_dir / "stdout.log" + ui_stderr_log_path = logs_dir / "stderr.log" + + # Open log files in append mode + stdout_log_file = open(ui_stdout_log_path, "a") + stderr_log_file = open(ui_stderr_log_path, "a") + + process = subprocess.Popen( + ["npm", "run", "dev"], + cwd=str(ui_dir), + stdout=stdout_log_file, + stderr=stderr_log_file, + env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"}, + ) + logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.") + logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}") + logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}") + + except FileNotFoundError: + logger.error( + "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH." 
+ ) + except Exception as e: + logger.error(f"Failed to start UI development server in {ui_dir}: {e}") diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index ccf1a5ffc..3aff78e23 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -7,12 +7,14 @@ import argparse from importlib.metadata import version +from llama_stack.cli.stack.list_stacks import StackListBuilds from llama_stack.cli.stack.utils import print_subcommand_description from llama_stack.cli.subcommand import Subcommand from .build import StackBuild from .list_apis import StackListApis from .list_providers import StackListProviders +from .remove import StackRemove from .run import StackRun @@ -41,5 +43,6 @@ class StackParser(Subcommand): StackListApis.create(subparsers) StackListProviders.create(subparsers) StackRun.create(subparsers) - + StackRemove.create(subparsers) + StackListBuilds.create(subparsers) print_subcommand_description(self.parser, subparsers) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 1d39063f0..072f9c425 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -6,6 +6,7 @@ import importlib.resources import logging +import sys from pathlib import Path from pydantic import BaseModel @@ -43,8 +44,20 @@ def get_provider_dependencies( # Extract providers based on config type if isinstance(config, DistributionTemplate): providers = config.providers + + # TODO: This is a hack to get the dependencies for internal APIs into build + # We should have a better way to do this by formalizing the concept of "internal" APIs + # and providers, with a way to specify dependencies for them. + run_configs = config.run_configs + additional_pip_packages: list[str] = [] + if run_configs: + for run_config in run_configs.values(): + run_config_ = run_config.run_config(name="", providers={}, container_image=None) + if run_config_.inference_store: + additional_pip_packages.extend(run_config_.inference_store.pip_packages) elif isinstance(config, BuildConfig): providers = config.distribution_spec.providers + additional_pip_packages = config.additional_pip_packages deps = [] registry = get_provider_registry(config) for api_str, provider_or_providers in providers.items(): @@ -72,6 +85,9 @@ def get_provider_dependencies( else: normal_deps.append(package) + if additional_pip_packages: + normal_deps.extend(additional_pip_packages) + return list(set(normal_deps)), list(set(special_deps)) @@ -80,10 +96,11 @@ def print_pip_install_help(config: BuildConfig): cprint( f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", - "yellow", + color="yellow", + file=sys.stderr, ) for special_dep in special_deps: - cprint(f"uv pip install {special_dep}", "yellow") + cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr) print() diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index 9fde8a157..03e4fb051 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -6,6 +6,7 @@ import inspect import json +import sys from collections.abc import AsyncIterator from enum import Enum from typing import Any, Union, get_args, get_origin @@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type: try: data = json.loads(data) if "error" in data: - cprint(data, "red") + cprint(data, color="red", file=sys.stderr) continue yield parse_obj_as(return_type, data) except Exception as e: - 
print(f"Error with parsing or validation: {e}") - print(data) + cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr) + cprint(data, color="red", file=sys.stderr) def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: webmethod, sig = self.routes[method_name] diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 446a88ca0..def7048c0 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -25,7 +25,8 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.datatypes import Api, ProviderSpec -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -220,21 +221,38 @@ class LoggingConfig(BaseModel): class AuthProviderType(str, Enum): """Supported authentication provider types.""" - KUBERNETES = "kubernetes" + OAUTH2_TOKEN = "oauth2_token" CUSTOM = "custom" class AuthenticationConfig(BaseModel): provider_type: AuthProviderType = Field( ..., - description="Type of authentication provider (e.g., 'kubernetes', 'custom')", + description="Type of authentication provider", ) - config: dict[str, str] = Field( + config: dict[str, Any] = Field( ..., description="Provider-specific configuration", ) +class AuthenticationRequiredError(Exception): + pass + + +class QuotaPeriod(str, Enum): + DAY = "day" + + +class QuotaConfig(BaseModel): + kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") + anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period") + authenticated_max_requests: int = Field( + default=1000, description="Max requests for authenticated clients per period" + ) + period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -262,6 +280,10 @@ class ServerConfig(BaseModel): default=None, description="The host the server should listen on", ) + quota: QuotaConfig | None = Field( + default=None, + description="Per client quota request configuration", + ) class StackRunConfig(BaseModel): @@ -297,6 +319,13 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) + inference_store: SqlStoreConfig | None = Field( + default=None, + description=""" +Configuration for the persistence store used by the inference API. If not specified, +a default SQLite store will be used.""", + ) + # registry of "resources" in the distribution models: list[ModelInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list) @@ -345,6 +374,10 @@ class BuildConfig(BaseModel): description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. " "pip_packages MUST contain the provider package name.", ) + additional_pip_packages: list[str] = Field( + default_factory=list, + description="Additional pip packages to install in the distribution. 
These packages will be installed in the distribution environment.", + ) @field_validator("external_providers_dir") @classmethod diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 23f644ec6..3321ec291 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -31,7 +31,7 @@ async def get_provider_impl(config, deps): class DistributionInspectImpl(Inspect): - def __init__(self, config, deps): + def __init__(self, config: DistributionInspectConfig, deps): self.config = config self.deps = deps @@ -39,22 +39,36 @@ class DistributionInspectImpl(Inspect): pass async def list_routes(self) -> ListRoutesResponse: - run_config = self.config.run_config + run_config: StackRunConfig = self.config.run_config ret = [] all_endpoints = get_all_api_endpoints() for api, endpoints in all_endpoints.items(): - providers = run_config.providers.get(api.value, []) - ret.extend( - [ - RouteInfo( - route=e.route, - method=e.method, - provider_types=[p.provider_type for p in providers], + # Always include provider and inspect APIs, filter others based on run config + if api.value in ["providers", "inspect"]: + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + ) + for e in endpoints + ] + ) + else: + providers = run_config.providers.get(api.value, []) + if providers: # Only process if there are providers for this API + ret.extend( + [ + RouteInfo( + route=e.route, + method=e.method, + provider_types=[p.provider_type for p in providers], + ) + for e in endpoints + ] ) - for e in endpoints - ] - ) return ListRoutesResponse(data=ret) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 34e041cbe..3cd2d1728 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,6 +9,7 @@ import inspect import json import logging import os +import sys from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path @@ -210,10 +211,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.endpoint_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) except ModuleNotFoundError as _e: - cprint(_e.msg, "red") + cprint(_e.msg, color="red", file=sys.stderr) cprint( "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n", - "yellow", + color="yellow", + file=sys.stderr, ) if self.config_path_or_template_name.endswith(".yaml"): # Convert Provider objects to their types @@ -234,7 +236,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cprint( f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n", "yellow", + file=sys.stderr, ) + cprint( + "Please check your internet connection and try again.", + "red", + file=sys.stderr, + ) raise _e if Api.telemetry in self.impls: @@ -261,9 +269,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): raise ValueError("Client not initialized") # Create headers with provider data if available - headers = {} + headers = options.headers or {} if self.provider_data: - headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) + keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"] + if all(key not in headers for key in keys): + headers["X-LlamaStack-Provider-Data"] = 
json.dumps(self.provider_data) # Use context manager for provider data with request_provider_data_context(headers): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 257c495c3..b7c7cb87f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import ( RemoteProviderSpec, ScoringFunctionsProtocolPrivate, ShieldsProtocolPrivate, - ToolsProtocolPrivate, + ToolGroupsProtocolPrivate, VectorDBsProtocolPrivate, ) @@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), - Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), + Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups), Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), @@ -140,7 +140,7 @@ async def resolve_impls( sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) - return await instantiate_providers(sorted_providers, router_apis, dist_registry) + return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config) def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: @@ -243,7 +243,10 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry + sorted_providers: list[tuple[str, ProviderWithSpec]], + router_apis: set[Api], + dist_registry: DistributionRegistry, + run_config: StackRunConfig, ) -> dict: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} @@ -258,7 +261,7 @@ async def instantiate_providers( if isinstance(provider.spec, RoutingTableProviderSpec): inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] - impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) + impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config) if api_str.startswith("inner-"): inner_impls_by_provider_id[api_str][provider.provider_id] = impl @@ -308,6 +311,7 @@ async def instantiate_provider( deps: dict[Api, Any], inner_impls: dict[str, Any], dist_registry: DistributionRegistry, + run_config: StackRunConfig, ): provider_spec = provider.spec if not hasattr(provider_spec, "module"): @@ -327,7 +331,7 @@ async def instantiate_provider( method = "get_auto_router_impl" config = None - args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] + args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config] elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index cd2a296f2..1358d5812 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,18 +7,10 @@ from typing import Any from llama_stack.distribution.datatypes import RoutedProtocol +from llama_stack.distribution.stack import StackRunConfig from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable - -from 
.routing_tables import ( - BenchmarksRoutingTable, - DatasetsRoutingTable, - ModelsRoutingTable, - ScoringFunctionsRoutingTable, - ShieldsRoutingTable, - ToolGroupsRoutingTable, - VectorDBsRoutingTable, -) +from llama_stack.providers.utils.inference.inference_store import InferenceStore async def get_routing_table_impl( @@ -27,6 +19,14 @@ async def get_routing_table_impl( _deps, dist_registry: DistributionRegistry, ) -> Any: + from ..routing_tables.benchmarks import BenchmarksRoutingTable + from ..routing_tables.datasets import DatasetsRoutingTable + from ..routing_tables.models import ModelsRoutingTable + from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable + from ..routing_tables.shields import ShieldsRoutingTable + from ..routing_tables.toolgroups import ToolGroupsRoutingTable + from ..routing_tables.vector_dbs import VectorDBsRoutingTable + api_to_tables = { "vector_dbs": VectorDBsRoutingTable, "models": ModelsRoutingTable, @@ -45,16 +45,15 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any: - from .routers import ( - DatasetIORouter, - EvalRouter, - InferenceRouter, - SafetyRouter, - ScoringRouter, - ToolRuntimeRouter, - VectorIORouter, - ) +async def get_auto_router_impl( + api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig +) -> Any: + from .datasets import DatasetIORouter + from .eval_scoring import EvalRouter, ScoringRouter + from .inference import InferenceRouter + from .safety import SafetyRouter + from .tool_runtime import ToolRuntimeRouter + from .vector_io import VectorIORouter api_to_routers = { "vector_io": VectorIORouter, @@ -76,6 +75,12 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict if dep_api in deps: api_to_dep_impl[dep_name] = deps[dep_api] + # TODO: move pass configs to routers instead + if api == Api.inference and run_config.inference_store: + inference_store = InferenceStore(run_config.inference_store) + await inference_store.initialize() + api_to_dep_impl["store"] = inference_store + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py new file mode 100644 index 000000000..6f28756c9 --- /dev/null +++ b/llama_stack/distribution/routers/datasets.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import DatasetPurpose, DataSource +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class DatasetIORouter(DatasetIO): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing DatasetIORouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("DatasetIORouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("DatasetIORouter.shutdown") + pass + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + + async def iterrows( + self, + dataset_id: str, + start_index: int | None = None, + limit: int | None = None, + ) -> PaginatedResponse: + logger.debug( + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", + ) + return await self.routing_table.get_provider_impl(dataset_id).iterrows( + dataset_id=dataset_id, + start_index=start_index, + limit=limit, + ) + + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: + logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/distribution/routers/eval_scoring.py new file mode 100644 index 000000000..fd0bb90a7 --- /dev/null +++ b/llama_stack/distribution/routers/eval_scoring.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringFnParams, +) +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class ScoringRouter(Scoring): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ScoringRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("ScoringRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("ScoringRouter.shutdown") + pass + + async def score_batch( + self, + dataset_id: str, + scoring_functions: dict[str, ScoringFnParams | None] = None, + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + logger.debug(f"ScoringRouter.score_batch: {dataset_id}") + res = {} + for fn_identifier in scoring_functions.keys(): + score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + dataset_id=dataset_id, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, + ) + res.update(score_response.results) + + if save_results_dataset: + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res, + ) + + async def score( + self, + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None] = None, + ) -> ScoreResponse: + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + res = {} + # look up and map each scoring function to its provider impl + for fn_identifier in scoring_functions.keys(): + score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + input_rows=input_rows, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, + ) + res.update(score_response.results) + + return ScoreResponse(results=res) + + +class EvalRouter(Eval): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing EvalRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("EvalRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("EvalRouter.shutdown") + pass + + async def run_eval( + self, + benchmark_id: str, + benchmark_config: BenchmarkConfig, + ) -> Job: + logger.debug(f"EvalRouter.run_eval: {benchmark_id}") + return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + benchmark_id=benchmark_id, + benchmark_config=benchmark_config, + ) + + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: list[dict[str, Any]], + scoring_functions: list[str], + benchmark_config: BenchmarkConfig, + ) -> EvaluateResponse: + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( + benchmark_id=benchmark_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + benchmark_config=benchmark_config, + ) + + async def job_status( + self, + benchmark_id: str, + job_id: str, + ) -> Job: + logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") + return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + + async def job_cancel( + self, + benchmark_id: str, + job_id: str, + ) -> None: + logger.debug(f"EvalRouter.job_cancel: 
{benchmark_id}, {job_id}") + await self.routing_table.get_provider_impl(benchmark_id).job_cancel( + benchmark_id, + job_id, + ) + + async def job_result( + self, + benchmark_id: str, + job_id: str, + ) -> EvaluateResponse: + logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") + return await self.routing_table.get_provider_impl(benchmark_id).job_result( + benchmark_id, + job_id, + ) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/inference.py similarity index 64% rename from llama_stack/distribution/routers/routers.py rename to llama_stack/distribution/routers/inference.py index 371d34904..f77b19302 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/inference.py @@ -14,14 +14,9 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo from pydantic import Field, TypeAdapter from llama_stack.apis.common.content_types import ( - URL, InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import DatasetPurpose, DataSource -from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -32,8 +27,11 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, Inference, + ListOpenAIChatCompletionResponse, LogProbConfig, Message, + OpenAICompletionWithInputMessages, + Order, ResponseFormat, SamplingParams, StopReason, @@ -51,89 +49,18 @@ from llama_stack.apis.inference.inference import ( OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.scoring import ( - ScoreBatchResponse, - ScoreResponse, - Scoring, - ScoringFnParams, -) -from llama_stack.apis.shields import Shield from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry -from llama_stack.apis.tools import ( - ListToolDefsResponse, - RAGDocument, - RAGQueryConfig, - RAGQueryResult, - RAGToolRuntime, - ToolRuntime, -) -from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable +from llama_stack.providers.utils.inference.inference_store import InferenceStore +from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") -class VectorIORouter(VectorIO): - """Routes to an provider based on the vector db identifier""" - - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing VectorIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("VectorIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("VectorIORouter.shutdown") - pass - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> None: - 
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") - await self.routing_table.register_vector_db( - vector_db_id, - embedding_model, - embedding_dimension, - provider_id, - provider_vector_db_id, - ) - - async def insert_chunks( - self, - vector_db_id: str, - chunks: list[Chunk], - ttl_seconds: int | None = None, - ) -> None: - logger.debug( - f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", - ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) - - async def query_chunks( - self, - vector_db_id: str, - query: InterleavedContent, - params: dict[str, Any] | None = None, - ) -> QueryChunksResponse: - logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) - - class InferenceRouter(Inference): """Routes to an provider based on the model""" @@ -141,10 +68,12 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, telemetry: Telemetry | None = None, + store: InferenceStore | None = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table self.telemetry = telemetry + self.store = store if self.telemetry: self.tokenizer = Tokenizer.get_instance() self.formatter = ChatFormat(self.tokenizer) @@ -607,9 +536,31 @@ class InferenceRouter(Inference): provider = self.routing_table.get_provider_impl(model_obj.identifier) if stream: - return await provider.openai_chat_completion(**params) + response_stream = await provider.openai_chat_completion(**params) + if self.store: + return stream_and_store_openai_completion(response_stream, model, self.store, messages) + return response_stream else: - return await self._nonstream_openai_chat_completion(provider, params) + response = await self._nonstream_openai_chat_completion(provider, params) + if self.store: + await self.store.store_chat_completion(response, messages) + return response + + async def list_chat_completions( + self, + after: str | None = None, + limit: int | None = 20, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIChatCompletionResponse: + if self.store: + return await self.store.list_chat_completions(after, limit, model, order) + raise NotImplementedError("List chat completions is not supported: inference store is not configured.") + + async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + if self.store: + return await self.store.get_chat_completion(completion_id) + raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) @@ -642,295 +593,3 @@ class InferenceRouter(Inference): status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" ) return health_statuses - - -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def 
register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) - - -class DatasetIORouter(DatasetIO): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing DatasetIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("DatasetIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("DatasetIORouter.shutdown") - pass - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> None: - logger.debug( - f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", - ) - await self.routing_table.register_dataset( - purpose=purpose, - source=source, - metadata=metadata, - dataset_id=dataset_id, - ) - - async def iterrows( - self, - dataset_id: str, - start_index: int | None = None, - limit: int | None = None, - ) -> PaginatedResponse: - logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", - ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( - dataset_id=dataset_id, - start_index=start_index, - limit=limit, - ) - - async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: - logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( - dataset_id=dataset_id, - rows=rows, - ) - - -class ScoringRouter(Scoring): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ScoringRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("ScoringRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ScoringRouter.shutdown") - pass - - async def score_batch( - self, - dataset_id: str, - scoring_functions: dict[str, ScoringFnParams | None] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - logger.debug(f"ScoringRouter.score_batch: {dataset_id}") - res = {} - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( - dataset_id=dataset_id, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - if save_results_dataset: - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res, - ) - - async def score( - self, - input_rows: list[dict[str, Any]], - scoring_functions: dict[str, ScoringFnParams | None] = None, - ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") - res = {} - # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions.keys(): - score_response = await 
self.routing_table.get_provider_impl(fn_identifier).score( - input_rows=input_rows, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - return ScoreResponse(results=res) - - -class EvalRouter(Eval): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing EvalRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("EvalRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("EvalRouter.shutdown") - pass - - async def run_eval( - self, - benchmark_id: str, - benchmark_config: BenchmarkConfig, - ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( - benchmark_id=benchmark_id, - benchmark_config=benchmark_config, - ) - - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: list[dict[str, Any]], - scoring_functions: list[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( - benchmark_id=benchmark_id, - input_rows=input_rows, - scoring_functions=scoring_functions, - benchmark_config=benchmark_config, - ) - - async def job_status( - self, - benchmark_id: str, - job_id: str, - ) -> Job: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) - - async def job_cancel( - self, - benchmark_id: str, - job_id: str, - ) -> None: - logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( - benchmark_id, - job_id, - ) - - async def job_result( - self, - benchmark_id: str, - job_id: str, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( - benchmark_id, - job_id, - ) - - -class ToolRuntimeRouter(ToolRuntime): - class RagToolImpl(RAGToolRuntime): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") - self.routing_table = routing_table - - async def query( - self, - content: InterleavedContent, - vector_db_ids: list[str], - query_config: RAGQueryConfig | None = None, - ) -> RAGQueryResult: - logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) - - async def insert( - self, - documents: list[RAGDocument], - vector_db_id: str, - chunk_size_in_tokens: int = 512, - ) -> None: - logger.debug( - f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" - ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) - - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter") - self.routing_table = routing_table - - # HACK ALERT this should be in sync with "get_all_api_endpoints()" - self.rag_tool = self.RagToolImpl(routing_table) - for method in ("query", "insert"): - setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) - - async def 
initialize(self) -> None: - logger.debug("ToolRuntimeRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ToolRuntimeRouter.shutdown") - pass - - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: - logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( - tool_name=tool_name, - kwargs=kwargs, - ) - - async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None - ) -> ListToolDefsResponse: - logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py deleted file mode 100644 index c04562197..000000000 --- a/llama_stack/distribution/routers/routing_tables.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import logging -import time -import uuid -from typing import Any - -from pydantic import TypeAdapter - -from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import ( - Dataset, - DatasetPurpose, - Datasets, - DatasetType, - DataSource, - ListDatasetsResponse, - RowsDataSource, - URIDataSource, -) -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.scoring_functions import ( - ListScoringFunctionsResponse, - ScoringFn, - ScoringFnParams, - ScoringFunctions, -) -from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields -from llama_stack.apis.tools import ( - ListToolGroupsResponse, - ListToolsResponse, - Tool, - ToolGroup, - ToolGroups, - ToolHost, -) -from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import ( - AccessAttributes, - BenchmarkWithACL, - DatasetWithACL, - ModelWithACL, - RoutableObject, - RoutableObjectWithProvider, - RoutedProtocol, - ScoringFnWithACL, - ShieldWithACL, - ToolGroupWithACL, - ToolWithACL, - VectorDBWithACL, -) -from llama_stack.distribution.request_headers import get_auth_attributes -from llama_stack.distribution.store import DistributionRegistry -from llama_stack.providers.datatypes import Api, RoutingTable - -logger = logging.getLogger(__name__) - - -def get_impl_api(p: Any) -> Api: - return p.__provider_spec__.api - - -# TODO: this should return the registered object for all APIs -async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: - api = get_impl_api(p) - - assert obj.provider_id != "remote", "Remote provider should not be registered" - - if api == Api.inference: - return await p.register_model(obj) - elif api == Api.safety: - return await p.register_shield(obj) - elif api == Api.vector_io: - return await p.register_vector_db(obj) - elif api == Api.datasetio: - return await p.register_dataset(obj) - elif api == Api.scoring: - return await 
p.register_scoring_function(obj) - elif api == Api.eval: - return await p.register_benchmark(obj) - elif api == Api.tool_runtime: - return await p.register_tool(obj) - else: - raise ValueError(f"Unknown API {api} for registering object with provider") - - -async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: - api = get_impl_api(p) - if api == Api.vector_io: - return await p.unregister_vector_db(obj.identifier) - elif api == Api.inference: - return await p.unregister_model(obj.identifier) - elif api == Api.datasetio: - return await p.unregister_dataset(obj.identifier) - elif api == Api.tool_runtime: - return await p.unregister_tool(obj.identifier) - else: - raise ValueError(f"Unregister not supported for {api}") - - -Registry = dict[str, list[RoutableObjectWithProvider]] - - -class CommonRoutingTableImpl(RoutingTable): - def __init__( - self, - impls_by_provider_id: dict[str, RoutedProtocol], - dist_registry: DistributionRegistry, - ) -> None: - self.impls_by_provider_id = impls_by_provider_id - self.dist_registry = dist_registry - - async def initialize(self) -> None: - async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: - for obj in objs: - if cls is None: - obj.provider_id = provider_id - else: - # Create a copy of the model data and explicitly set provider_id - model_data = obj.model_dump() - model_data["provider_id"] = provider_id - obj = cls(**model_data) - await self.dist_registry.register(obj) - - # Register all objects from providers - for pid, p in self.impls_by_provider_id.items(): - api = get_impl_api(p) - if api == Api.inference: - p.model_store = self - elif api == Api.safety: - p.shield_store = self - elif api == Api.vector_io: - p.vector_db_store = self - elif api == Api.datasetio: - p.dataset_store = self - elif api == Api.scoring: - p.scoring_function_store = self - scoring_functions = await p.list_scoring_functions() - await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: - p.benchmark_store = self - elif api == Api.tool_runtime: - p.tool_store = self - - async def shutdown(self) -> None: - for p in self.impls_by_provider_id.values(): - await p.shutdown() - - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: - def apiname_object(): - if isinstance(self, ModelsRoutingTable): - return ("Inference", "model") - elif isinstance(self, ShieldsRoutingTable): - return ("Safety", "shield") - elif isinstance(self, VectorDBsRoutingTable): - return ("VectorIO", "vector_db") - elif isinstance(self, DatasetsRoutingTable): - return ("DatasetIO", "dataset") - elif isinstance(self, ScoringFunctionsRoutingTable): - return ("Scoring", "scoring_function") - elif isinstance(self, BenchmarksRoutingTable): - return ("Eval", "benchmark") - elif isinstance(self, ToolGroupsRoutingTable): - return ("Tools", "tool") - else: - raise ValueError("Unknown routing table type") - - apiname, objtype = apiname_object() - - # Get objects from disk registry - obj = self.dist_registry.get_cached(objtype, routing_key) - if not obj: - provider_ids = list(self.impls_by_provider_id.keys()) - if len(provider_ids) > 1: - provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" - else: - provider_ids_str = f"provider: `{provider_ids[0]}`" - raise ValueError( - f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
- ) - - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] - - raise ValueError(f"Provider not found for `{routing_key}`") - - async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - # Get from disk registry - obj = await self.dist_registry.get(type, identifier) - if not obj: - return None - - # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") - return None - - return obj - - async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: - await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: - # if provider_id is not specified, pick an arbitrary one from existing entries - if not obj.provider_id and len(self.impls_by_provider_id) > 0: - obj.provider_id = list(self.impls_by_provider_id.keys())[0] - - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") - - p = self.impls_by_provider_id[obj.provider_id] - - # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") - - registered_obj = await register_object_with_provider(obj, p) - # TODO: This needs to be fixed for all APIs once they return the registered object - if obj.type == ResourceType.model.value: - await self.dist_registry.register(registered_obj) - return registered_obj - - else: - await self.dist_registry.register(obj) - return obj - - async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: - objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] - - # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) - ] - - return filtered_objs - - -class ModelsRoutingTable(CommonRoutingTableImpl, Models): - async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) - - async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") - openai_models = [ - OpenAIModel( - id=model.identifier, - object="model", - created=int(time.time()), - owned_by="llama_stack", - ) - for model in models - ] - return OpenAIListModelsResponse(data=openai_models) - - async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model - - async def register_model( - self, - model_id: str, - provider_model_id: str | None = None, - provider_id: str | None = None, - metadata: dict[str, Any] | None = None, - model_type: ModelType | None = None, - ) -> Model: - if provider_model_id is None: - provider_model_id = model_id - if provider_id is 
None: - # If provider_id not specified, use the only provider if it supports this model - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" - ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm - if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( - identifier=model_id, - provider_resource_id=provider_model_id, - provider_id=provider_id, - metadata=metadata, - model_type=model_type, - ) - registered_model = await self.register_object(model) - return registered_model - - async def unregister_model(self, model_id: str) -> None: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - await self.unregister_object(existing_model) - - -class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): - async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - - async def get_shield(self, identifier: str) -> Shield: - shield = await self.get_object_by_identifier("shield", identifier) - if shield is None: - raise ValueError(f"Shield '{identifier}' not found") - return shield - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - if provider_shield_id is None: - provider_shield_id = shield_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - if params is None: - params = {} - shield = ShieldWithACL( - identifier=shield_id, - provider_resource_id=provider_shield_id, - provider_id=provider_id, - params=params, - ) - await self.register_object(shield) - return shield - - -class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): - async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - - async def get_vector_db(self, vector_db_id: str) -> VectorDB: - vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) - if vector_db is None: - raise ValueError(f"Vector DB '{vector_db_id}' not found") - return vector_db - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id - if provider_id is None: - if len(self.impls_by_provider_id) > 0: - provider_id = list(self.impls_by_provider_id.keys())[0] - if len(self.impls_by_provider_id) > 1: - logger.warning( - f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." - ) - else: - raise ValueError("No provider available. 
Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) - if model is None: - raise ValueError(f"Model {embedding_model} not found") - if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") - if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") - vector_db_data = { - "identifier": vector_db_id, - "type": ResourceType.vector_db.value, - "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, - "embedding_model": embedding_model, - "embedding_dimension": model.metadata["embedding_dimension"], - } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) - await self.register_object(vector_db) - return vector_db - - async def unregister_vector_db(self, vector_db_id: str) -> None: - existing_vector_db = await self.get_vector_db(vector_db_id) - if existing_vector_db is None: - raise ValueError(f"Vector DB {vector_db_id} not found") - await self.unregister_object(existing_vector_db) - - -class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): - async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - - async def get_dataset(self, dataset_id: str) -> Dataset: - dataset = await self.get_object_by_identifier("dataset", dataset_id) - if dataset is None: - raise ValueError(f"Dataset '{dataset_id}' not found") - return dataset - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> Dataset: - if isinstance(source, dict): - if source["type"] == "uri": - source = URIDataSource.parse_obj(source) - elif source["type"] == "rows": - source = RowsDataSource.parse_obj(source) - - if not dataset_id: - dataset_id = f"dataset-{str(uuid.uuid4())}" - - provider_dataset_id = dataset_id - - # infer provider from source - if metadata: - if metadata.get("provider_id"): - provider_id = metadata.get("provider_id") # pass through from nvidia datasetio - elif source.type == DatasetType.rows.value: - provider_id = "localfs" - elif source.type == DatasetType.uri.value: - # infer provider from uri - if source.uri.startswith("huggingface"): - provider_id = "huggingface" - else: - provider_id = "localfs" - else: - raise ValueError(f"Unknown data source type: {source.type}") - - if metadata is None: - metadata = {} - - dataset = DatasetWithACL( - identifier=dataset_id, - provider_resource_id=provider_dataset_id, - provider_id=provider_id, - purpose=purpose, - source=source, - metadata=metadata, - ) - - await self.register_object(dataset) - return dataset - - async def unregister_dataset(self, dataset_id: str) -> None: - dataset = await self.get_dataset(dataset_id) - if dataset is None: - raise ValueError(f"Dataset {dataset_id} not found") - await self.unregister_object(dataset) - - -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - - async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) - if scoring_fn is None: - raise ValueError(f"Scoring function '{scoring_fn_id}' not 
found") - return scoring_fn - - async def register_scoring_function( - self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: str | None = None, - provider_id: str | None = None, - params: ScoringFnParams | None = None, - ) -> None: - if provider_scoring_fn_id is None: - provider_scoring_fn_id = scoring_fn_id - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - scoring_fn = ScoringFnWithACL( - identifier=scoring_fn_id, - description=description, - return_type=return_type, - provider_resource_id=provider_scoring_fn_id, - provider_id=provider_id, - params=params, - ) - scoring_fn.provider_id = provider_id - await self.register_object(scoring_fn) - - -class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): - async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - - async def get_benchmark(self, benchmark_id: str) -> Benchmark: - benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) - if benchmark is None: - raise ValueError(f"Benchmark '{benchmark_id}' not found") - return benchmark - - async def register_benchmark( - self, - benchmark_id: str, - dataset_id: str, - scoring_functions: list[str], - metadata: dict[str, Any] | None = None, - provider_benchmark_id: str | None = None, - provider_id: str | None = None, - ) -> None: - if metadata is None: - metadata = {} - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." 
- ) - if provider_benchmark_id is None: - provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( - identifier=benchmark_id, - dataset_id=dataset_id, - scoring_functions=scoring_functions, - metadata=metadata, - provider_id=provider_id, - provider_resource_id=provider_benchmark_id, - ) - await self.register_object(benchmark) - - -class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): - async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") - if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) - - async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) - - async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group '{toolgroup_id}' not found") - return tool_group - - async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) - - async def register_tool_group( - self, - toolgroup_id: str, - provider_id: str, - mcp_endpoint: URL | None = None, - args: dict[str, Any] | None = None, - ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - - await self.dist_registry.register( - ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) - ) - - async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) - await self.unregister_object(tool_group) - - async def shutdown(self) -> None: - pass diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py new file mode 100644 index 000000000..9761d2db0 --- /dev/null +++ b/llama_stack/distribution/routers/safety.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.inference import ( + Message, +) +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + return await self.routing_table.get_provider_impl(shield_id).run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py new file mode 100644 index 000000000..285843dbc --- /dev/null +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, +) +from llama_stack.apis.tools import ( + ListToolsResponse, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolRuntime, +) +from llama_stack.log import get_logger + +from ..routing_tables.toolgroups import ToolGroupsRoutingTable + +logger = get_logger(name=__name__, category="core") + + +class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: ToolGroupsRoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") + self.routing_table = routing_table + + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + return await self.routing_table.get_provider_impl("knowledge_search").query( + content, vector_db_ids, query_config + ) + + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" + ) + return await self.routing_table.get_provider_impl("insert_into_memory").insert( + documents, vector_db_id, chunk_size_in_tokens + ) + + def __init__( + self, + routing_table: ToolGroupsRoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter") + self.routing_table = routing_table + + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + self.rag_tool = self.RagToolImpl(routing_table) + for method in ("query", "insert"): + setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) + + async def initialize(self) -> None: + logger.debug("ToolRuntimeRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("ToolRuntimeRouter.shutdown") + pass + + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: + logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") + return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + tool_name=tool_name, + kwargs=kwargs, + ) + + async def list_runtime_tools( + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + ) -> ListToolsResponse: + logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") + return await self.routing_table.list_tools(tool_group_id) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py new file mode 100644 index 000000000..8c17aa890 --- /dev/null +++ b/llama_stack/distribution/routers/vector_io.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
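The `setattr(self, f"rag_tool.{method}", ...)` loop above registers attributes whose names literally contain a dot; per the HACK comment, this presumably keeps the handler names in sync with the generated endpoint names so dispatch can fetch them by their composite name with `getattr`. A small stand-alone demonstration of that Python behaviour (the `Holder` class is purely illustrative):

```python
# setattr with a dotted name creates an attribute literally called "rag_tool.bump".
# It is not reachable as obj.rag_tool.bump via normal attribute syntax, only via getattr.
class Holder:
    def __init__(self) -> None:
        self.value = 41

    def bump(self) -> int:
        self.value += 1
        return self.value


h = Holder()
setattr(h, "rag_tool.bump", h.bump)   # attribute name contains a literal dot

print("rag_tool.bump" in vars(h))     # True: stored under the dotted key in __dict__
print(getattr(h, "rag_tool.bump")())  # 42: the bound method retrieved by its dotted name
# print(h.rag_tool.bump)              # would raise AttributeError: there is no "rag_tool" attribute chain
```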
+
+from typing import Any
+
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+)
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import RoutingTable
+
+logger = get_logger(name=__name__, category="core")
+
+
+class VectorIORouter(VectorIO):
+    """Routes to a provider based on the vector db identifier"""
+
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing VectorIORouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("VectorIORouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("VectorIORouter.shutdown")
+        pass
+
+    async def register_vector_db(
+        self,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
+    ) -> None:
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        await self.routing_table.register_vector_db(
+            vector_db_id,
+            embedding_model,
+            embedding_dimension,
+            provider_id,
+            provider_vector_db_id,
+        )
+
+    async def insert_chunks(
+        self,
+        vector_db_id: str,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
+    ) -> None:
+        logger.debug(
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+
+    async def query_chunks(
+        self,
+        vector_db_id: str,
+        query: InterleavedContent,
+        params: dict[str, Any] | None = None,
+    ) -> QueryChunksResponse:
+        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
diff --git a/llama_stack/distribution/routing_tables/__init__.py b/llama_stack/distribution/routing_tables/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/distribution/routing_tables/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py
new file mode 100644
index 000000000..589a00c02
--- /dev/null
+++ b/llama_stack/distribution/routing_tables/benchmarks.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
+from llama_stack.distribution.datatypes import (
+    BenchmarkWithACL,
+)
+from llama_stack.log import get_logger
+
+from .common import CommonRoutingTableImpl
+
+logger = get_logger(name=__name__, category="core")
+
+
+class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
+    async def list_benchmarks(self) -> ListBenchmarksResponse:
+        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
+
+    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
+        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
+        if benchmark is None:
+            raise ValueError(f"Benchmark '{benchmark_id}' not found")
+        return benchmark
+
+    async def register_benchmark(
+        self,
+        benchmark_id: str,
+        dataset_id: str,
+        scoring_functions: list[str],
+        metadata: dict[str, Any] | None = None,
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+    ) -> None:
+        if metadata is None:
+            metadata = {}
+        if provider_id is None:
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. Please specify a provider_id."
+                )
+        if provider_benchmark_id is None:
+            provider_benchmark_id = benchmark_id
+        benchmark = BenchmarkWithACL(
+            identifier=benchmark_id,
+            dataset_id=dataset_id,
+            scoring_functions=scoring_functions,
+            metadata=metadata,
+            provider_id=provider_id,
+            provider_resource_id=provider_benchmark_id,
+        )
+        await self.register_object(benchmark)
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
new file mode 100644
index 000000000..8ec87ca50
--- /dev/null
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
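The provider-defaulting rule in `register_benchmark` above recurs throughout these routing tables: omit `provider_id` only when exactly one provider is configured, otherwise fail loudly. A hedged, stand-alone restatement of that rule (the function and provider names here are illustrative, not part of the patch):

```python
# Restatement of the shared provider-selection rule: defaulting is only safe
# when the choice is unambiguous.
def resolve_provider_id(requested: str | None, impls_by_provider_id: dict[str, object]) -> str:
    if requested is not None:
        return requested
    if len(impls_by_provider_id) == 1:
        # Exactly one provider registered: use it.
        return list(impls_by_provider_id.keys())[0]
    raise ValueError("No provider specified and multiple providers available. Please specify a provider_id.")


if __name__ == "__main__":
    print(resolve_provider_id(None, {"basic": object()}))  # -> "basic"
    try:
        resolve_provider_id(None, {"basic": object(), "braintrust": object()})
    except ValueError as err:
        print(err)  # ambiguous -> rejected
```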
+ +from typing import Any + +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.datatypes import ( + AccessAttributes, + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) +from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.store import DistributionRegistry +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import Api, RoutingTable + +logger = get_logger(name=__name__, category="core") + + +def get_impl_api(p: Any) -> Api: + return p.__provider_spec__.api + + +# TODO: this should return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) + + assert obj.provider_id != "remote", "Remote provider should not be registered" + + if api == Api.inference: + return await p.register_model(obj) + elif api == Api.safety: + return await p.register_shield(obj) + elif api == Api.vector_io: + return await p.register_vector_db(obj) + elif api == Api.datasetio: + return await p.register_dataset(obj) + elif api == Api.scoring: + return await p.register_scoring_function(obj) + elif api == Api.eval: + return await p.register_benchmark(obj) + elif api == Api.tool_runtime: + return await p.register_toolgroup(obj) + else: + raise ValueError(f"Unknown API {api} for registering object with provider") + + +async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: + api = get_impl_api(p) + if api == Api.vector_io: + return await p.unregister_vector_db(obj.identifier) + elif api == Api.inference: + return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await p.unregister_dataset(obj.identifier) + elif api == Api.tool_runtime: + return await p.unregister_toolgroup(obj.identifier) + else: + raise ValueError(f"Unregister not supported for {api}") + + +Registry = dict[str, list[RoutableObjectWithProvider]] + + +class CommonRoutingTableImpl(RoutingTable): + def __init__( + self, + impls_by_provider_id: dict[str, RoutedProtocol], + dist_registry: DistributionRegistry, + ) -> None: + self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry + + async def initialize(self) -> None: + async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: + for obj in objs: + if cls is None: + obj.provider_id = provider_id + else: + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) + await self.dist_registry.register(obj) + + # Register all objects from providers + for pid, p in self.impls_by_provider_id.items(): + api = get_impl_api(p) + if api == Api.inference: + p.model_store = self + elif api == Api.safety: + p.shield_store = self + elif api == Api.vector_io: + p.vector_db_store = self + elif api == Api.datasetio: + p.dataset_store = self + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + await add_objects(scoring_functions, pid, ScoringFn) + elif api == Api.eval: + p.benchmark_store = self + elif api == Api.tool_runtime: + p.tool_store = self + + async def shutdown(self) -> None: + for p in self.impls_by_provider_id.values(): + await p.shutdown() + + def get_provider_impl(self, routing_key: str, 
provider_id: str | None = None) -> Any: + from .benchmarks import BenchmarksRoutingTable + from .datasets import DatasetsRoutingTable + from .models import ModelsRoutingTable + from .scoring_functions import ScoringFunctionsRoutingTable + from .shields import ShieldsRoutingTable + from .toolgroups import ToolGroupsRoutingTable + from .vector_dbs import VectorDBsRoutingTable + + def apiname_object(): + if isinstance(self, ModelsRoutingTable): + return ("Inference", "model") + elif isinstance(self, ShieldsRoutingTable): + return ("Safety", "shield") + elif isinstance(self, VectorDBsRoutingTable): + return ("VectorIO", "vector_db") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") + elif isinstance(self, BenchmarksRoutingTable): + return ("Eval", "benchmark") + elif isinstance(self, ToolGroupsRoutingTable): + return ("ToolGroups", "tool_group") + else: + raise ValueError("Unknown routing table type") + + apiname, objtype = apiname_object() + + # Get objects from disk registry + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: + provider_ids = list(self.impls_by_provider_id.keys()) + if len(provider_ids) > 1: + provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" + else: + provider_ids_str = f"provider: `{provider_ids[0]}`" + raise ValueError( + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." + ) + + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] + + raise ValueError(f"Provider not found for `{routing_key}`") + + async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + # Get from disk registry + obj = await self.dist_registry.get(type, identifier) + if not obj: + return None + + # Check if user has permission to access this object + if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): + logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + return None + + return obj + + async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + # if provider_id is not specified, pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + + p = self.impls_by_provider_id[obj.provider_id] + + # If object supports access control but no attributes set, use creator's attributes + if not obj.access_attributes: + creator_attributes = get_auth_attributes() + if creator_attributes: + obj.access_attributes = AccessAttributes(**creator_attributes) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await 
self.dist_registry.register(registered_obj) + return registered_obj + + else: + await self.dist_registry.register(obj) + return obj + + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + filtered_objs = [obj for obj in objs if obj.type == type] + + # Apply attribute-based access control filtering + if filtered_objs: + filtered_objs = [ + obj + for obj in filtered_objs + if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + ] + + return filtered_objs diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py new file mode 100644 index 000000000..4401ad47e --- /dev/null +++ b/llama_stack/distribution/routing_tables/datasets.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid +from typing import Any + +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, + RowsDataSource, + URIDataSource, +) +from llama_stack.apis.resource import ResourceType +from llama_stack.distribution.datatypes import ( + DatasetWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): + async def list_datasets(self) -> ListDatasetsResponse: + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if metadata: + if metadata.get("provider_id"): + provider_id = metadata.get("provider_id") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: + provider_id = "localfs" + elif source.type == DatasetType.uri.value: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + + if metadata is None: + metadata = {} + + dataset = DatasetWithACL( + identifier=dataset_id, + provider_resource_id=provider_dataset_id, + provider_id=provider_id, + purpose=purpose, + source=source, + metadata=metadata, + ) + + await self.register_object(dataset) + return dataset + + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py new file mode 
100644 index 000000000..7216d9935 --- /dev/null +++ b/llama_stack/distribution/routing_tables/models.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from typing import Any + +from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel +from llama_stack.distribution.datatypes import ( + ModelWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ModelsRoutingTable(CommonRoutingTableImpl, Models): + async def list_models(self) -> ListModelsResponse: + return ListModelsResponse(data=await self.get_all_with_type("model")) + + async def openai_list_models(self) -> OpenAIListModelsResponse: + models = await self.get_all_with_type("model") + openai_models = [ + OpenAIModel( + id=model.identifier, + object="model", + created=int(time.time()), + owned_by="llama_stack", + ) + for model in models + ] + return OpenAIListModelsResponse(data=openai_models) + + async def get_model(self, model_id: str) -> Model: + model = await self.get_object_by_identifier("model", model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + return model + + async def register_model( + self, + model_id: str, + provider_model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, Any] | None = None, + model_type: ModelType | None = None, + ) -> Model: + if provider_model_id is None: + provider_model_id = model_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this model + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + ) + if metadata is None: + metadata = {} + if model_type is None: + model_type = ModelType.llm + if "embedding_dimension" not in metadata and model_type == ModelType.embedding: + raise ValueError("Embedding model must have an embedding dimension in its metadata") + model = ModelWithACL( + identifier=model_id, + provider_resource_id=provider_model_id, + provider_id=provider_id, + metadata=metadata, + model_type=model_type, + ) + registered_model = await self.register_object(model) + return registered_model + + async def unregister_model(self, model_id: str) -> None: + existing_model = await self.get_model(model_id) + if existing_model is None: + raise ValueError(f"Model {model_id} not found") + await self.unregister_object(existing_model) diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/distribution/routing_tables/scoring_functions.py new file mode 100644 index 000000000..d85f64b57 --- /dev/null +++ b/llama_stack/distribution/routing_tables/scoring_functions.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
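`register_model` above enforces that embedding models declare their dimension up front, since later steps (such as vector DB registration) read it from model metadata. A compact, illustrative restatement of that check, with a stand-in `ModelType` enum rather than the llama-stack one:

```python
# Stand-alone restatement of the embedding-model guard in register_model.
from enum import Enum
from typing import Any


class ModelType(str, Enum):  # illustrative stand-in for llama_stack.apis.models.ModelType
    llm = "llm"
    embedding = "embedding"


def validate_model_metadata(model_type: ModelType, metadata: dict[str, Any]) -> None:
    if model_type == ModelType.embedding and "embedding_dimension" not in metadata:
        raise ValueError("Embedding model must have an embedding dimension in its metadata")


validate_model_metadata(ModelType.embedding, {"embedding_dimension": 384})  # ok
try:
    validate_model_metadata(ModelType.embedding, {})
except ValueError as err:
    print(err)  # missing dimension -> rejected at registration time
```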
+ +from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ListScoringFunctionsResponse, + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.distribution.datatypes import ( + ScoringFnWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): + async def list_scoring_functions(self) -> ListScoringFunctionsResponse: + return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + + async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + if scoring_fn is None: + raise ValueError(f"Scoring function '{scoring_fn_id}' not found") + return scoring_fn + + async def register_scoring_function( + self, + scoring_fn_id: str, + description: str, + return_type: ParamType, + provider_scoring_fn_id: str | None = None, + provider_id: str | None = None, + params: ScoringFnParams | None = None, + ) -> None: + if provider_scoring_fn_id is None: + provider_scoring_fn_id = scoring_fn_id + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + scoring_fn = ScoringFnWithACL( + identifier=scoring_fn_id, + description=description, + return_type=return_type, + provider_resource_id=provider_scoring_fn_id, + provider_id=provider_id, + params=params, + ) + scoring_fn.provider_id = provider_id + await self.register_object(scoring_fn) diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py new file mode 100644 index 000000000..7f62596c9 --- /dev/null +++ b/llama_stack/distribution/routing_tables/shields.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields +from llama_stack.distribution.datatypes import ( + ShieldWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def list_shields(self) -> ListShieldsResponse: + return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + + async def get_shield(self, identifier: str) -> Shield: + shield = await self.get_object_by_identifier("shield", identifier) + if shield is None: + raise ValueError(f"Shield '{identifier}' not found") + return shield + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + if provider_shield_id is None: + provider_shield_id = shield_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this shield type + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if params is None: + params = {} + shield = ShieldWithACL( + identifier=shield_id, + provider_resource_id=provider_shield_id, + provider_id=provider_id, + params=params, + ) + await self.register_object(shield) + return shield diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py new file mode 100644 index 000000000..2f7dc3e06 --- /dev/null +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups +from llama_stack.distribution.datatypes import ToolGroupWithACL +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: + # handle the funny case like "builtin::rag/knowledge_search" + parts = toolgroup_name_with_maybe_tool_name.split("/") + if len(parts) == 2: + return parts[0] + else: + return None + + +class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + toolgroups_to_tools: dict[str, list[Tool]] = {} + tool_to_toolgroup: dict[str, str] = {} + + # overridden + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id + # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? 
+ + toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key) + if toolgroup_id: + routing_key = toolgroup_id + + if routing_key in self.tool_to_toolgroup: + routing_key = self.tool_to_toolgroup[routing_key] + return super().get_provider_impl(routing_key, provider_id) + + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: + if toolgroup_id: + if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): + toolgroup_id = group_id + toolgroups = [await self.get_tool_group(toolgroup_id)] + else: + toolgroups = await self.get_all_with_type("tool_group") + + all_tools = [] + for toolgroup in toolgroups: + if toolgroup.identifier not in self.toolgroups_to_tools: + await self._index_tools(toolgroup) + all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) + + return ListToolsResponse(data=all_tools) + + async def _index_tools(self, toolgroup: ToolGroup): + provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) + + # TODO: kill this Tool vs ToolDef distinction + tooldefs = tooldefs_response.data + tools = [] + for t in tooldefs: + tools.append( + Tool( + identifier=t.name, + toolgroup_id=toolgroup.identifier, + description=t.description or "", + parameters=t.parameters or [], + metadata=t.metadata, + provider_id=toolgroup.provider_id, + ) + ) + + self.toolgroups_to_tools[toolgroup.identifier] = tools + for tool in tools: + self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier + + async def list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) + + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group '{toolgroup_id}' not found") + return tool_group + + async def get_tool(self, tool_name: str) -> Tool: + if tool_name in self.tool_to_toolgroup: + toolgroup_id = self.tool_to_toolgroup[tool_name] + tools = self.toolgroups_to_tools[toolgroup_id] + for tool in tools: + if tool.identifier == tool_name: + return tool + raise ValueError(f"Tool '{tool_name}' not found") + + async def register_tool_group( + self, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: URL | None = None, + args: dict[str, Any] | None = None, + ) -> None: + toolgroup = ToolGroupWithACL( + identifier=toolgroup_id, + provider_id=provider_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, + ) + await self.register_object(toolgroup) + + # ideally, indexing of the tools should not be necessary because anyone using + # the tools should first list the tools and then use them. but there are assumptions + # baked in some of the code and tests right now. 
+ if not toolgroup.mcp_endpoint: + await self._index_tools(toolgroup) + return toolgroup + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + tool_group = await self.get_tool_group(toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group {toolgroup_id} not found") + await self.unregister_object(tool_group) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py new file mode 100644 index 000000000..dc6c0d0ef --- /dev/null +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import TypeAdapter + +from llama_stack.apis.models import ModelType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.distribution.datatypes import ( + VectorDBWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): + async def list_vector_dbs(self) -> ListVectorDBsResponse: + return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) + + async def get_vector_db(self, vector_db_id: str) -> VectorDB: + vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) + if vector_db is None: + raise ValueError(f"Vector DB '{vector_db_id}' not found") + return vector_db + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorDB: + if provider_vector_db_id is None: + provider_vector_db_id = vector_db_id + if provider_id is None: + if len(self.impls_by_provider_id) > 0: + provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) + else: + raise ValueError("No provider available. 
Please configure a vector_io provider.") + model = await self.get_object_by_identifier("model", embedding_model) + if model is None: + raise ValueError(f"Model {embedding_model} not found") + if model.model_type != ModelType.embedding: + raise ValueError(f"Model {embedding_model} is not an embedding model") + if "embedding_dimension" not in model.metadata: + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + vector_db_data = { + "identifier": vector_db_id, + "type": ResourceType.vector_db.value, + "provider_id": provider_id, + "provider_resource_id": provider_vector_db_id, + "embedding_model": embedding_model, + "embedding_dimension": model.metadata["embedding_dimension"], + } + vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) + await self.register_object(vector_db) + return vector_db + + async def unregister_vector_db(self, vector_db_id: str) -> None: + existing_vector_db = await self.get_vector_db(vector_db_id) + if existing_vector_db is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + await self.unregister_object(existing_vector_db) diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index 83436c51f..fb26b49a7 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -8,7 +8,8 @@ import json import httpx -from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider +from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.server.auth_providers import create_auth_provider from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -77,7 +78,7 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. """ - def __init__(self, app, auth_config: AuthProviderConfig): + def __init__(self, app, auth_config: AuthenticationConfig): self.app = app self.auth_provider = create_auth_provider(auth_config) @@ -113,6 +114,10 @@ class AuthenticationMiddleware: "roles": [token], } + # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) + # can identify the requester and enforce per-client rate limits. + scope["authenticated_client_id"] = token + # Store attributes in request scope scope["user_attributes"] = user_attributes scope["principal"] = validation_result.principal diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 4065a65f3..723a65b77 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -4,17 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
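To summarize the vector DB registration completed above: the routing table never accepts an arbitrary dimension from the caller; it looks up the embedding model and copies `embedding_dimension` from that model's metadata. A hedged sketch of that flow, with an in-memory dict standing in for the distribution registry (model and provider names are made up):

```python
# Illustrative sketch: the vector DB record inherits its dimension from the
# registered embedding model, mirroring the checks in register_vector_db.
from typing import Any

model_registry: dict[str, dict[str, Any]] = {
    "all-MiniLM-L6-v2": {"model_type": "embedding", "metadata": {"embedding_dimension": 384}},
}


def build_vector_db_record(vector_db_id: str, embedding_model: str, provider_id: str) -> dict[str, Any]:
    model = model_registry.get(embedding_model)
    if model is None:
        raise ValueError(f"Model {embedding_model} not found")
    if model["model_type"] != "embedding":
        raise ValueError(f"Model {embedding_model} is not an embedding model")
    if "embedding_dimension" not in model["metadata"]:
        raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
    return {
        "identifier": vector_db_id,
        "type": "vector_db",
        "provider_id": provider_id,
        "provider_resource_id": vector_db_id,
        "embedding_model": embedding_model,
        "embedding_dimension": model["metadata"]["embedding_dimension"],
    }


print(build_vector_db_record("my-docs", "all-MiniLM-L6-v2", "faiss"))
```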
-import json +import ssl import time from abc import ABC, abstractmethod -from enum import Enum +from asyncio import Lock +from pathlib import Path from urllib.parse import parse_qs import httpx from jose import jwt -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self -from llama_stack.distribution.datatypes import AccessAttributes +from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -72,21 +74,6 @@ class AuthRequest(BaseModel): request: AuthRequestContext = Field(description="Context information about the request being authenticated") -class AuthProviderType(str, Enum): - """Supported authentication provider types.""" - - KUBERNETES = "kubernetes" - CUSTOM = "custom" - OAUTH2_TOKEN = "oauth2_token" - - -class AuthProviderConfig(BaseModel): - """Base configuration for authentication providers.""" - - provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: dict[str, str] = Field(..., description="Provider-specific configuration") - - class AuthProvider(ABC): """Abstract base class for authentication providers.""" @@ -101,83 +88,6 @@ class AuthProvider(ABC): pass -class KubernetesAuthProviderConfig(BaseModel): - api_server_url: str - ca_cert_path: str | None = None - - -class KubernetesAuthProvider(AuthProvider): - """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - - def __init__(self, config: KubernetesAuthProviderConfig): - self.config = config - self._client = None - - async def _get_client(self): - """Get or create a Kubernetes client.""" - if self._client is None: - # kubernetes-client has not async support, see: - # https://github.com/kubernetes-client/python/issues/323 - from kubernetes import client - from kubernetes.client import ApiClient - - # Configure the client - configuration = client.Configuration() - configuration.host = self.config.api_server_url - if self.config.ca_cert_path: - configuration.ssl_ca_cert = self.config.ca_cert_path - configuration.verify_ssl = bool(self.config.ca_cert_path) - - # Create API client - self._client = ApiClient(configuration) - return self._client - - async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: - """Validate a Kubernetes token and return access attributes.""" - try: - client = await self._get_client() - - # Set the token in the client - client.set_default_header("Authorization", f"Bearer {token}") - - # Make a request to validate the token - # We use the /api endpoint which requires authentication - from kubernetes.client import CoreV1Api - - api = CoreV1Api(client) - api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request - - # If we get here, the token is valid - # Extract user info from the token claims - import base64 - - # Decode the token (without verification since we've already validated it) - token_parts = token.split(".") - payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4))) - - # Extract user information from the token - username = payload.get("sub", "") - groups = payload.get("groups", []) - - return TokenValidationResult( - principal=username, - access_attributes=AccessAttributes( - roles=[username], # Use username as a role - teams=groups, # Use Kubernetes groups as 
teams - ), - ) - - except Exception as e: - logger.exception("Failed to validate Kubernetes token") - raise ValueError("Invalid or expired token") from e - - async def close(self): - """Close the HTTP client.""" - if self._client: - self._client.close() - self._client = None - - def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes: attributes = AccessAttributes() for claim_key, attribute_key in mapping.items(): @@ -197,11 +107,24 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) return attributes -class OAuth2TokenAuthProviderConfig(BaseModel): +class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys - jwks_uri: str - cache_ttl: int = 3600 + uri: str + key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") + + +class OAuth2IntrospectionConfig(BaseModel): + url: str + client_id: str + client_secret: str + send_secret_in_body: bool = False + + +class OAuth2TokenAuthProviderConfig(BaseModel): audience: str = "llama-stack" + verify_tls: bool = True + tls_cafile: Path | None = None + issuer: str | None = Field(default=None, description="The OIDC issuer URL.") claims_mapping: dict[str, str] = Field( default_factory=lambda: { "sub": "roles", @@ -213,6 +136,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel): "namespace": "namespaces", }, ) + jwks: OAuth2JWKSConfig | None + introspection: OAuth2IntrospectionConfig | None = None @classmethod @field_validator("claims_mapping") @@ -224,6 +149,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel): raise ValueError(f"claims_mapping value is not a valid attribute: {value}") return v + @model_validator(mode="after") + def validate_mode(self) -> Self: + if not self.jwks and not self.introspection: + raise ValueError("One of jwks or introspection must be configured") + if self.jwks and self.introspection: + raise ValueError("At present only one of jwks or introspection should be configured") + return self + class OAuth2TokenAuthProvider(AuthProvider): """ @@ -236,8 +169,16 @@ class OAuth2TokenAuthProvider(AuthProvider): self.config = config self._jwks_at: float = 0.0 self._jwks: dict[str, str] = {} + self._jwks_lock = Lock() async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + if self.config.jwks: + return await self.validate_jwt_token(token, scope) + if self.config.introspection: + return await self.introspect_token(token, scope) + raise ValueError("One of jwks or introspection must be configured") + + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: """Validate a token using the JWT token.""" await self._refresh_jwks() @@ -253,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider): key_data, algorithms=[algorithm], audience=self.config.audience, - options={"verify_exp": True}, + issuer=self.config.issuer, ) except Exception as exc: raise ValueError(f"Invalid JWT token: {token}") from exc @@ -267,21 +208,84 @@ class OAuth2TokenAuthProvider(AuthProvider): access_attributes=access_attributes, ) + async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult: + """Validate a token using token introspection as defined by RFC 7662.""" + form = { + "token": token, + } + if self.config.introspection is None: + raise ValueError("Introspection is not configured") + + if self.config.introspection.send_secret_in_body: + form["client_id"] = self.config.introspection.client_id + 
form["client_secret"] = self.config.introspection.client_secret + auth = None + else: + auth = (self.config.introspection.client_id, self.config.introspection.client_secret) + ssl_ctxt = None + if self.config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) + try: + async with httpx.AsyncClient(verify=ssl_ctxt) as client: + response = await client.post( + self.config.introspection.url, + data=form, + auth=auth, + timeout=10.0, # Add a reasonable timeout + ) + if response.status_code != 200: + logger.warning(f"Token introspection failed with status code: {response.status_code}") + raise ValueError(f"Token introspection failed: {response.status_code}") + + fields = response.json() + if not fields["active"]: + raise ValueError("Token not active") + principal = fields["sub"] or fields["username"] + access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping) + return TokenValidationResult( + principal=principal, + access_attributes=access_attributes, + ) + except httpx.TimeoutException: + logger.exception("Token introspection request timed out") + raise + except ValueError: + # Re-raise ValueError exceptions to preserve their message + raise + except Exception as e: + logger.exception("Error during token introspection") + raise ValueError("Token introspection error") from e + async def close(self): - """Close the HTTP client.""" + pass async def _refresh_jwks(self) -> None: - if time.time() - self._jwks_at > self.config.cache_ttl: - async with httpx.AsyncClient() as client: - res = await client.get(self.config.jwks_uri, timeout=5) - res.raise_for_status() - jwks_data = res.json()["keys"] - self._jwks = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - self._jwks[kid] = k - self._jwks_at = time.time() + """ + Refresh the JWKS cache. + + This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). + If the cache is expired, we refresh the JWKS from the JWKS URI. 
+ + Notes: for Kubernetes which doesn't fully implement the OIDC protocol: + * It doesn't have user authentication flows + * It doesn't have refresh tokens + """ + async with self._jwks_lock: + if self.config.jwks is None: + raise ValueError("JWKS is not configured") + if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: + verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls + async with httpx.AsyncClient(verify=verify) as client: + res = await client.get(self.config.jwks.uri, timeout=5) + res.raise_for_status() + jwks_data = res.json()["keys"] + updated = {} + for k in jwks_data: + kid = k["kid"] + # Store the entire key object as it may be needed for different algorithms + updated[kid] = k + self._jwks = updated + self._jwks_at = time.time() class CustomAuthProviderConfig(BaseModel): @@ -359,13 +363,11 @@ class CustomAuthProvider(AuthProvider): self._client = None -def create_auth_provider(config: AuthProviderConfig) -> AuthProvider: +def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: """Factory function to create the appropriate auth provider.""" provider_type = config.provider_type.lower() - if provider_type == "kubernetes": - return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config)) - elif provider_type == "custom": + if provider_type == "custom": return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config)) elif provider_type == "oauth2_token": return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config)) diff --git a/llama_stack/distribution/server/quota.py b/llama_stack/distribution/server/quota.py new file mode 100644 index 000000000..ddbffae64 --- /dev/null +++ b/llama_stack/distribution/server/quota.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import time +from datetime import datetime, timedelta, timezone + +from starlette.types import ASGIApp, Receive, Scope, Send + +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl + +logger = get_logger(name=__name__, category="quota") + + +class QuotaMiddleware: + """ + ASGI middleware that enforces separate quotas for authenticated and anonymous clients + within a configurable time window. + + - For authenticated requests, it reads the client ID from the + `Authorization: Bearer ` header. + - For anonymous requests, it falls back to the IP address of the client. + Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned + once a client exceeds its quota. + """ + + def __init__( + self, + app: ASGIApp, + kv_config: KVStoreConfig, + anonymous_max_requests: int, + authenticated_max_requests: int, + window_seconds: int = 86400, + ): + self.app = app + self.kv_config = kv_config + self.kv: KVStore | None = None + self.anonymous_max_requests = anonymous_max_requests + self.authenticated_max_requests = authenticated_max_requests + self.window_seconds = window_seconds + + if isinstance(self.kv_config, SqliteKVStoreConfig): + logger.warning( + "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. 
" + f"window_seconds={self.window_seconds}" + ) + + async def _get_kv(self) -> KVStore: + if self.kv is None: + self.kv = await kvstore_impl(self.kv_config) + return self.kv + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] == "http": + # pick key & limit based on auth + auth_id = scope.get("authenticated_client_id") + if auth_id: + key_id = auth_id + limit = self.authenticated_max_requests + else: + # fallback to IP + client = scope.get("client") + key_id = client[0] if client else "anonymous" + limit = self.anonymous_max_requests + + current_window = int(time.time() // self.window_seconds) + key = f"quota:{key_id}:{current_window}" + + try: + kv = await self._get_kv() + prev = await kv.get(key) or "0" + count = int(prev) + 1 + + if int(prev) == 0: + # Set with expiration datetime when it is the first request in the window. + expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds) + await kv.set(key, str(count), expiration=expiration) + else: + await kv.set(key, str(count)) + except Exception: + logger.exception("Failed to access KV store for quota") + return await self._send_error(send, 500, "Quota service error") + + if count > limit: + logger.warning( + "Quota exceeded for client %s: %d/%d", + key_id, + count, + limit, + ) + return await self._send_error(send, 429, "Quota exceeded") + + return await self.app(scope, receive, send) + + async def _send_error(self, send: Send, status: int, message: str): + await send( + { + "type": "http.response.start", + "status": status, + "headers": [[b"content-type", b"application/json"]], + } + ) + body = json.dumps({"error": {"message": message}}).encode() + await send({"type": "http.response.body", "body": body}) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 15a8058ae..d70f06691 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -23,11 +23,12 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig +from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( PROVIDER_DATA_VAR, @@ -60,6 +61,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints +from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -120,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + elif isinstance(exc, AuthenticationRequiredError): + return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") else: return HTTPException( status_code=500, @@ -280,7 +284,18 @@ class TracingMiddleware: logger.debug(f"No matching endpoint found for path: {path}, 
falling back to FastAPI") return await self.app(scope, receive, send) - trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): if message["type"] == "http.response.start": @@ -370,14 +385,6 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - # Check for deprecated argument usage - if "--config" in sys.argv: - warnings.warn( - "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.", - DeprecationWarning, - stacklevel=2, - ) - log_line = "" if args.config: # if the user provided a config file, use it, even if template was specified @@ -431,6 +438,46 @@ def main(args: argparse.Namespace | None = None): if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) + else: + if config.server.quota: + quota = config.server.quota + logger.warning( + "Configured authenticated_max_requests (%d) but no auth is enabled; " + "falling back to anonymous_max_requests (%d) for all the requests", + quota.authenticated_max_requests, + quota.anonymous_max_requests, + ) + + if config.server.quota: + logger.info("Enabling quota middleware for authenticated and anonymous clients") + + quota = config.server.quota + anonymous_max_requests = quota.anonymous_max_requests + # if auth is disabled, use the anonymous max requests + authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests + + kv_config = quota.kvstore + window_map = {"day": 86400} + window_seconds = window_map[quota.period.value] + + app.add_middleware( + QuotaMiddleware, + kv_config=kv_config, + anonymous_max_requests=anonymous_max_requests, + authenticated_max_requests=authenticated_max_requests, + window_seconds=window_seconds, + ) + + # --- CORS middleware for local development --- + # TODO: move to reverse proxy + ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322) + app.add_middleware( + CORSMiddleware, + allow_origins=[f"http://localhost:{ui_port}"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) try: impls = asyncio.run(construct_stack(config)) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index a6b400136..0e84854c2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,7 +36,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v8" +KEY_VERSION = "v9" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 4acce4f5b..7c2e00524 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -8,6 +8,7 @@ import logging import os import signal import subprocess +import sys 
from termcolor import cprint @@ -33,6 +34,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( "No current conda environment detected, please specify a conda environment name with --image-name", color="red", + file=sys.stderr, ) return @@ -49,12 +51,13 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: return envpath return None - print(f"Using conda environment: {env_name}") + cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr) conda_prefix = get_conda_prefix(env_name) if not conda_prefix: cprint( f"Conda environment {env_name} does not exist.", color="red", + file=sys.stderr, ) return @@ -63,6 +66,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name", color="red", + file=sys.stderr, ) return else: @@ -73,9 +77,10 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list: cprint( "No current virtual environment detected, please specify a virtual environment name with --image-name", color="red", + file=sys.stderr, ) return - print(f"Using virtual environment: {env_name}") + cprint(f"Using virtual environment: {env_name}", file=sys.stderr) script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh" run_args = [ diff --git a/llama_stack/log.py b/llama_stack/log.py index 98858d208..f4184710a 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -6,6 +6,7 @@ import logging import os +import sys from logging.config import dictConfig from rich.console import Console @@ -234,7 +235,7 @@ def get_logger( env_config = os.environ.get("LLAMA_STACK_LOGGING", "") if env_config: - cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow") + cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr) _category_levels.update(parse_environment_config(env_config)) log_file = os.environ.get("LLAMA_STACK_LOG_FILE") diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index c6d618818..fe7be5ea9 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -174,6 +174,7 @@ class Llama3: cprint( "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", "red", + file=sys.stderr, ) prompt_tokens = [inp.tokens for inp in llm_inputs] @@ -184,7 +185,11 @@ class Llama3: max_prompt_len = max(len(t) for t in prompt_tokens) if max_prompt_len >= params.max_seq_len: - cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") + cprint( + f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", + color="red", + file=sys.stderr, + ) return total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py index 476761209..6132d25d4 100644 --- a/llama_stack/models/llama/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -133,9 +133,9 @@ class Llama4: print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" if print_model_input: - cprint("Input to model:\n", "yellow") + cprint("Input to model:\n", color="yellow", file=sys.stderr) for inp in llm_inputs: - cprint(self.tokenizer.decode(inp.tokens), "grey") + 
cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr) prompt_tokens = [inp.tokens for inp in llm_inputs] bsz = len(llm_inputs) @@ -145,7 +145,7 @@ class Llama4: max_prompt_len = max(len(t) for t in prompt_tokens) if max_prompt_len >= params.max_seq_len: - cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") + cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr) return total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 3e9806f23..60b05545b 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import Tool +from llama_stack.apis.tools import ToolGroup from llama_stack.apis.vector_dbs import VectorDB from llama_stack.schema_utils import json_schema_type @@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol): async def register_benchmark(self, benchmark: Benchmark) -> None: ... -class ToolsProtocolPrivate(Protocol): - async def register_tool(self, tool: Tool) -> None: ... +class ToolGroupsProtocolPrivate(Protocol): + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ... - async def unregister_tool(self, tool_id: str) -> None: ... + async def unregister_toolgroup(self, toolgroup_id: str) -> None: ... @json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 86780fd61..bcbfcbe31 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,9 +20,12 @@ from llama_stack.apis.agents import ( AgentTurnCreateRequest, AgentTurnResumeRequest, Document, + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, + Order, Session, Turn, ) @@ -39,6 +42,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records +from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig @@ -66,15 +70,17 @@ class MetaReferenceAgentsImpl(Agents): self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() - self.openai_responses_impl = None + self.openai_responses_impl: OpenAIResponsesImpl | None = None async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) + self.responses_store = ResponsesStore(self.config.responses_store) + await self.responses_store.initialize() self.openai_responses_impl = OpenAIResponsesImpl( - self.persistence_store, inference_api=self.inference_api, tool_groups_api=self.tool_groups_api, tool_runtime_api=self.tool_runtime_api, + responses_store=self.responses_store, ) async def create_agent( @@ -305,14 +311,15 @@ class MetaReferenceAgentsImpl(Agents): # OpenAI responses async def get_openai_response( self, - id: str, + response_id: str, ) -> OpenAIResponseObject: - return await 
self.openai_responses_impl.get_openai_response(id) + return await self.openai_responses_impl.get_openai_response(response_id) async def create_openai_response( self, input: str | list[OpenAIResponseInput], model: str, + instructions: str | None = None, previous_response_id: str | None = None, store: bool | None = True, stream: bool | None = False, @@ -320,5 +327,27 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, previous_response_id, store, stream, temperature, tools + input, model, instructions, previous_response_id, store, stream, temperature, tools + ) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) + + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + return await self.openai_responses_impl.list_openai_response_input_items( + response_id, after, before, include, limit, order ) diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index c860e6df1..1c392f29c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -10,10 +10,12 @@ from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig + responses_store: SqlStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: @@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel): "persistence_store": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, db_name="agents_store.db", - ) + ), + "responses_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="responses_store.db", + ), } diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 6d9d06109..3a56d41ef 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
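Since `responses_store` is now a required field on `MetaReferenceAgentsImplConfig`, a run config must supply both stores. A hedged sketch of what the sample config above yields (the `__distro_dir__` value is an arbitrary example):

from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig

sample = MetaReferenceAgentsImplConfig.sample_run_config(__distro_dir__="~/.llama/distributions/ollama")
# One SQLite-backed KV store for agent persistence, one SQL store backing the new responses store.
print(sample["persistence_store"])
print(sample["responses_store"])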
import json +import time import uuid from collections.abc import AsyncIterator from typing import Any, cast @@ -12,24 +13,29 @@ from typing import Any, cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel +from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, - OpenAIResponseInputItemList, OpenAIResponseInputMessageContent, OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentText, OpenAIResponseInputTool, - OpenAIResponseInputToolFunction, + OpenAIResponseInputToolMCP, OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAIResponseOutput, OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPListTools, OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.inference.inference import ( @@ -49,11 +55,12 @@ from llama_stack.apis.inference.inference import ( OpenAIToolMessageParam, OpenAIUserMessageParam, ) -from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools logger = get_logger(name=__name__, category="openai_responses") @@ -162,41 +169,43 @@ async def _get_message_type_by_role(role: str): class OpenAIResponsePreviousResponseWithInputItems(BaseModel): - input_items: OpenAIResponseInputItemList + input_items: ListOpenAIResponseInputItem response: OpenAIResponseObject +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + tools: list[ChatCompletionToolParam] | None = None + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] + stream: bool + temperature: float | None + + class OpenAIResponsesImpl: def __init__( self, - persistence_store: KVStore, inference_api: Inference, tool_groups_api: ToolGroups, tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, ): - self.persistence_store = persistence_store self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api - - async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems: - key = f"{OPENAI_RESPONSES_PREFIX}{id}" - response_json = await self.persistence_store.get(key=key) - if response_json is None: - raise ValueError(f"OpenAI response with id '{id}' not found") - return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json) + self.responses_store = responses_store async def _prepend_previous_response( self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None ): if previous_response_id: - previous_response_with_input = await self._get_previous_response_with_input(previous_response_id) + 
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) # previous response input items - new_input_items = previous_response_with_input.input_items.data + new_input_items = previous_response_with_input.input # previous response output items - new_input_items.extend(previous_response_with_input.response.output) + new_input_items.extend(previous_response_with_input.output) # new input items from the current request if isinstance(input, str): @@ -208,99 +217,60 @@ class OpenAIResponsesImpl: return input + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + async def get_openai_response( self, - id: str, + response_id: str, ) -> OpenAIResponseObject: - response_with_input = await self._get_previous_response_with_input(id) - return response_with_input.response + response_with_input = await self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) - async def create_openai_response( + async def list_openai_responses( self, - input: str | list[OpenAIResponseInput], - model: str, - previous_response_id: str | None = None, - store: bool | None = True, - stream: bool | None = False, - temperature: float | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - ): - stream = False if stream is None else stream + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) - input = await self._prepend_previous_response(input, previous_response_id) - messages = await _convert_response_input_to_chat_messages(input) - chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None - chat_response = await self.inference_api.openai_chat_completion( - model=model, - messages=messages, - tools=chat_tools, - stream=stream, - temperature=temperature, - ) + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. 
- if stream: - # TODO: refactor this into a separate method that handles streaming - chat_response_id = "" - chat_response_content = [] - chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} - # TODO: these chunk_ fields are hacky and only take the last chunk into account - chunk_created = 0 - chunk_model = "" - chunk_finish_reason = "" - async for chunk in chat_response: - chat_response_id = chunk.id - chunk_created = chunk.created - chunk_model = chunk.model - for chunk_choice in chunk.choices: - # TODO: this only works for text content - chat_response_content.append(chunk_choice.delta.content or "") - if chunk_choice.finish_reason: - chunk_finish_reason = chunk_choice.finish_reason - - # Aggregate tool call arguments across chunks, using their index as the aggregation key - if chunk_choice.delta.tool_calls: - for tool_call in chunk_choice.delta.tool_calls: - response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - if response_tool_call: - response_tool_call.function.arguments += tool_call.function.arguments - else: - tool_call_dict: dict[str, Any] = tool_call.model_dump() - # Ensure we don't have any empty type field in the tool call dict. - # The OpenAI client used by providers often returns a type=None here. - tool_call_dict.pop("type", None) - response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call - - # Convert the dict of tool calls by index to a list of tool calls to pass back in our response - if chat_response_tool_calls: - tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - else: - tool_calls = None - assistant_message = OpenAIAssistantMessageParam( - content="".join(chat_response_content), - tool_calls=tool_calls, - ) - chat_response = OpenAIChatCompletion( - id=chat_response_id, - choices=[ - OpenAIChoice( - message=assistant_message, - finish_reason=chunk_finish_reason, - index=0, - ) - ], - created=chunk_created, - model=chunk_model, - ) - else: - # dump and reload to map to our pydantic types - chat_response = OpenAIChatCompletion(**chat_response.model_dump()) + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. + :returns: An ListOpenAIResponseInputItem. 
+ """ + return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + async def _process_response_choices( + self, + chat_response: OpenAIChatCompletion, + ctx: ChatCompletionContext, + tools: list[OpenAIResponseInputTool] | None, + ) -> list[OpenAIResponseOutput]: + """Handle tool execution and response message creation.""" output_messages: list[OpenAIResponseOutput] = [] + # Execute tool calls if any for choice in chat_response.choices: if choice.message.tool_calls and tools: # Assume if the first tool is a function, all tools are functions - if isinstance(tools[0], OpenAIResponseInputToolFunction): + if tools[0].type == "function": for tool_call in choice.message.tool_calls: output_messages.append( OpenAIResponseOutputMessageFunctionToolCall( @@ -312,11 +282,133 @@ class OpenAIResponsesImpl: ) ) else: - output_messages.extend( - await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature) - ) + tool_messages = await self._execute_tool_and_return_final_output(choice, ctx) + output_messages.extend(tool_messages) else: output_messages.append(await _convert_chat_choice_to_response_message(choice)) + + return output_messages + + async def _store_response( + self, + response: OpenAIResponseObject, + original_input: str | list[OpenAIResponseInput], + ) -> None: + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(original_input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=original_input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in original_input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, + ) + + async def create_openai_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + ): + stream = False if stream is None else stream + original_input = input # Keep reference for storage + + output_messages: list[OpenAIResponseOutput] = [] + + # Input preprocessing + input = await self._prepend_previous_response(input, previous_response_id) + messages = await _convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) + + # Tool setup + chat_tools, mcp_tool_to_server, mcp_list_message = ( + await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) + ) + if mcp_list_message: + output_messages.append(mcp_list_message) + + ctx = ChatCompletionContext( + model=model, + messages=messages, + tools=chat_tools, + mcp_tool_to_server=mcp_tool_to_server, + stream=stream, + temperature=temperature, + ) + + inference_result = await self.inference_api.openai_chat_completion( + model=model, + messages=messages, + 
tools=chat_tools, + stream=stream, + temperature=temperature, + ) + + if stream: + return self._create_streaming_response( + inference_result=inference_result, + ctx=ctx, + output_messages=output_messages, + original_input=original_input, + model=model, + store=store, + tools=tools, + ) + else: + return await self._create_non_streaming_response( + inference_result=inference_result, + ctx=ctx, + output_messages=output_messages, + original_input=original_input, + model=model, + store=store, + tools=tools, + ) + + async def _create_non_streaming_response( + self, + inference_result: Any, + ctx: ChatCompletionContext, + output_messages: list[OpenAIResponseOutput], + original_input: str | list[OpenAIResponseInput], + model: str, + store: bool | None, + tools: list[OpenAIResponseInputTool] | None, + ) -> OpenAIResponseObject: + chat_response = OpenAIChatCompletion(**inference_result.model_dump()) + + # Process response choices (tool execution and message creation) + output_messages.extend( + await self._process_response_choices( + chat_response=chat_response, + ctx=ctx, + tools=tools, + ) + ) + response = OpenAIResponseObject( created_at=chat_response.created, id=f"resp-{uuid.uuid4()}", @@ -327,57 +419,168 @@ class OpenAIResponsesImpl: ) logger.debug(f"OpenAI Responses response: {response}") + # Store response if requested if store: - # Store in kvstore - - new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(input, str): - # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=input) - input_content_item = OpenAIResponseMessage( - role="user", - content=[input_content], - id=new_input_id, - ) - input_items_data = [input_content_item] - else: - # we already have a list of messages - input_items_data = [] - for input_item in input: - if isinstance(input_item, OpenAIResponseMessage): - # These may or may not already have an id, so dump to dict, check for id, and add if missing - input_item_dict = input_item.model_dump() - if "id" not in input_item_dict: - input_item_dict["id"] = new_input_id - input_items_data.append(OpenAIResponseMessage(**input_item_dict)) - else: - input_items_data.append(input_item) - - input_items = OpenAIResponseInputItemList(data=input_items_data) - prev_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, + await self._store_response( response=response, + original_input=original_input, ) - key = f"{OPENAI_RESPONSES_PREFIX}{response.id}" - await self.persistence_store.set( - key=key, - value=prev_response.model_dump_json(), - ) - - if stream: - - async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]: - # TODO: response created should actually get emitted much earlier in the process - yield OpenAIResponseObjectStreamResponseCreated(response=response) - yield OpenAIResponseObjectStreamResponseCompleted(response=response) - - return async_response() return response + async def _create_streaming_response( + self, + inference_result: Any, + ctx: ChatCompletionContext, + output_messages: list[OpenAIResponseOutput], + original_input: str | list[OpenAIResponseInput], + model: str, + store: bool | None, + tools: list[OpenAIResponseInputTool] | None, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + # Create initial response and emit response.created immediately + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + initial_response = OpenAIResponseObject( + created_at=created_at, + id=response_id, + model=model, + object="response", + status="in_progress", + 
output=output_messages.copy(), + ) + + # Emit response.created immediately + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + # For streaming, inference_result is an async iterator of chunks + # Stream chunks and emit delta events as they arrive + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + sequence_number = 0 + + # Create a placeholder message item for delta events + message_item_id = f"msg_{uuid.uuid4()}" + + async for chunk in inference_result: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # Emit incremental text content as delta events + if chunk_choice.delta.content: + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=0, + delta=chunk_choice.delta.content, + item_id=message_item_id, + output_index=0, + sequence_number=sequence_number, + ) + + # Collect content for final response + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks, using their index as the aggregation key + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + if response_tool_call: + response_tool_call.function.arguments += tool_call.function.arguments + else: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Convert collected chunks to complete response + if chat_response_tool_calls: + tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] + else: + tool_calls = None + assistant_message = OpenAIAssistantMessageParam( + content="".join(chat_response_content), + tool_calls=tool_calls, + ) + chat_response_obj = OpenAIChatCompletion( + id=chat_response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=chunk_finish_reason, + index=0, + ) + ], + created=chunk_created, + model=chunk_model, + ) + + # Process response choices (tool execution and message creation) + output_messages.extend( + await self._process_response_choices( + chat_response=chat_response_obj, + ctx=ctx, + tools=tools, + ) + ) + + # Create final response + final_response = OpenAIResponseObject( + created_at=created_at, + id=response_id, + model=model, + object="response", + status="completed", + output=output_messages, + ) + + if store: + await self._store_response( + response=final_response, + original_input=original_input, + ) + + # Emit response.completed + yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) + async def _convert_response_tools_to_chat_tools( self, tools: list[OpenAIResponseInputTool] - ) -> list[ChatCompletionToolParam]: + ) -> tuple[ + list[ChatCompletionToolParam], + dict[str, OpenAIResponseInputToolMCP], + OpenAIResponseOutput | None, + ]: + from llama_stack.apis.agents.openai_responses import ( + MCPListToolsTool, + ) + from llama_stack.apis.tools.tools import Tool + + mcp_tool_to_server = {} + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + 
tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + mcp_list_message = None chat_tools: list[ChatCompletionToolParam] = [] for input_tool in tools: # TODO: Handle other tool types @@ -386,91 +589,95 @@ class OpenAIResponsesImpl: elif input_tool.type == "web_search": tool_name = "web_search" tool = await self.tool_groups_api.get_tool(tool_name) - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, + if not tool: + raise ValueError(f"Tool {tool_name} not found") + chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + always_allowed = None + never_allowed = None + if input_tool.allowed_tools: + if isinstance(input_tool.allowed_tools, list): + always_allowed = input_tool.allowed_tools + elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): + always_allowed = input_tool.allowed_tools.always + never_allowed = input_tool.allowed_tools.never + + tool_defs = await list_mcp_tools( + endpoint=input_tool.server_url, + headers=input_tool.headers or {}, ) - chat_tool = convert_tooldef_to_openai_tool(tool_def) - chat_tools.append(chat_tool) + + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + status="completed", + server_label=input_tool.server_label, + tools=[], + ) + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + chat_tools.append(make_openai_tool(t.name, t)) + if t.name in mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") + mcp_tool_to_server[t.name] = input_tool + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) else: raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools + return chat_tools, mcp_tool_to_server, mcp_list_message async def _execute_tool_and_return_final_output( self, - model_id: str, - stream: bool, choice: OpenAIChoice, - messages: list[OpenAIMessageParam], - temperature: float, + ctx: ChatCompletionContext, ) -> list[OpenAIResponseOutput]: output_messages: list[OpenAIResponseOutput] = [] - # If the choice is not an assistant message, we don't need to execute any tools if not isinstance(choice.message, OpenAIAssistantMessageParam): return output_messages - # If the assistant message doesn't have any tool calls, we don't need to execute any tools if not choice.message.tool_calls: return output_messages - # Copy the messages list to avoid mutating the original list - messages = messages.copy() + next_turn_messages = ctx.messages.copy() # Add the assistant message with tool_calls response to the messages list - messages.append(choice.message) + next_turn_messages.append(choice.message) for tool_call in 
choice.message.tool_calls: - tool_call_id = tool_call.id - function = tool_call.function - - # If for some reason the tool call doesn't have a function or id, we can't execute it - if not function or not tool_call_id: - continue - # TODO: telemetry spans for tool calls - result = await self._execute_tool_call(function) - - # Handle tool call failure - if not result: - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="failed", - ) - ) - continue - - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ), - ) - - result_content = "" - # TODO: handle other result content types and lists - if isinstance(result.content, str): - result_content = result.content - messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id)) + tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx) + if tool_call_log: + output_messages.append(tool_call_log) + if further_input: + next_turn_messages.append(further_input) tool_results_chat_response = await self.inference_api.openai_chat_completion( - model=model_id, - messages=messages, - stream=stream, - temperature=temperature, + model=ctx.model, + messages=next_turn_messages, + stream=ctx.stream, + temperature=ctx.temperature, ) - # type cast to appease mypy + # type cast to appease mypy: this is needed because we don't handle streaming properly :) tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) + + # Huge TODO: these are NOT the final outputs, we must keep the loop going tool_final_outputs = [ await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices ] @@ -480,15 +687,86 @@ class OpenAIResponsesImpl: async def _execute_tool_call( self, - function: OpenAIChatCompletionToolCallFunction, - ) -> ToolInvocationResult | None: - if not function.name: - return None - function_args = json.loads(function.arguments) if function.arguments else {} - logger.info(f"executing tool call: {function.name} with args: {function_args}") - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=function_args, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, ) - logger.debug(f"tool call {function.name} completed with result: {result}") - return result + + tool_call_id = tool_call.id + function = tool_call.function + + if not function or not tool_call_id or not function.name: + return None, None + + error_exc = None + result = None + try: + if function.name in ctx.mcp_tool_to_server: + mcp_tool = ctx.mcp_tool_to_server[function.name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + except Exception as e: + error_exc = e + + if function.name in ctx.mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + 
server_label=ctx.mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result.error_code and result.error_code > 0) or result.error_message: + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if error_exc or (result.error_code and result.error_code > 0) or result.error_message: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem + + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 048336f9e..e238e1b78 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -6,6 +6,7 @@ import asyncio import os +import sys from collections.abc import AsyncGenerator from pydantic import BaseModel @@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl( first = token_results[0] if not first.finished and not first.ignore_token: if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): - cprint(first.text, "cyan", end="") + cprint(first.text, color="cyan", end="", file=sys.stderr) if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{first.token}>", "magenta", end="") + cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr) for result in token_results: idx = result.batch_idx @@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl( for token_results in self.generator.chat_completion([request]): token_result = token_results[0] if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": - cprint(token_result.text, "cyan", end="") + cprint(token_result.text, color="cyan", end="", file=sys.stderr) if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{token_result.token}>", "magenta", end="") + cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr) if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 67362dd36..0f6cf8619 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import 
Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from llama_stack.apis.telemetry import ( Event, @@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor ) from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore +from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS from .config import TelemetryConfig, TelemetrySink @@ -146,7 +148,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if span: timestamp_ns = int(event.timestamp.timestamp() * 1e9) span.add_event( - name=event.type, + name=event.type.value, attributes={ "message": event.message, "severity": event.severity.value, @@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): event.attributes = {} event.attributes["__ttl__"] = ttl_seconds + # Extract these W3C trace context attributes so they are not written to + # underlying storage, as we just need them to propagate the trace context. + traceparent = event.attributes.pop("traceparent", None) + tracestate = event.attributes.pop("tracestate", None) + if traceparent: + # If we have a traceparent header value, we're not the root span. + for root_attribute in ROOT_SPAN_MARKERS: + event.attributes.pop(root_attribute, None) + if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates if span_id in _GLOBAL_STORAGE["active_spans"]: @@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): parent_span_id = int(event.payload.parent_span_id, 16) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) context = trace.set_span_in_context(parent_span) - else: - event.attributes["__root_span__"] = "true" + elif traceparent: + carrier = { + "traceparent": traceparent, + "tracestate": tracestate, + } + context = TraceContextTextMapPropagator().extract(carrier=carrier) span = tracer.start_span( name=event.payload.name, diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 39f752297..c2d264c91 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -25,14 +25,14 @@ from llama_stack.apis.tools import ( RAGQueryConfig, RAGQueryResult, RAGToolRuntime, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, @@ -49,7 +49,7 @@ def make_random_string(length: int = 8): return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) -class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): +class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: RagToolRuntimeConfig, @@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, 
RAGToolRuntime): async def shutdown(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return async def insert( @@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): query=query, params={ "max_chunks": query_config.max_chunks, + "mode": query_config.mode, }, ) for vector_db_id in vector_db_ids @@ -146,7 +147,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): for i, chunk in enumerate(chunks): metadata = chunk.metadata tokens += metadata["token_count"] - tokens += metadata["metadata_token_count"] + tokens += metadata.get("metadata_token_count", 0) if tokens > query_config.max_tokens_in_context: log.error( diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index d3dc7e694..47256d88d 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector( + self, + embedding: NDArray, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k) - chunks = [] scores = [] for d, i in zip(distances[0], indices[0], strict=False): @@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in FAISS") + class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index ab4384021..fc1a8ddb0 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect logger = logging.getLogger(__name__) +# Specifying search mode is dependent on the VectorIO provider. +VECTOR_SEARCH = "vector" +KEYWORD_SEARCH = "keyword" +SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} + def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex): Two tables are used: - A metadata table (chunks_{bank_id}) that holds the chunk JSON. - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. + - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. 
""" def __init__(self, dimension: int, db_path: str, bank_id: str): @@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex): self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") + self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex): USING vec0(embedding FLOAT[{self.dimension}], id TEXT); """) connection.commit() + # FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one + # based on query. Implementation of the change on client side will allow passing the search_mode option + # during initialization to make it easier to create the table that is required. + cur.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} + USING fts5(id, content); + """) + connection.commit() finally: cur.close() connection.close() @@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") + cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") connection.commit() finally: cur.close() @@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex): For each chunk, we insert its JSON into the metadata table and then insert its embedding (serialized to raw bytes) into the virtual table using the assigned rowid. If any insert fails, the transaction is rolled back to maintain consistency. + Also inserts chunk content into FTS table for keyword search support. """ assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" @@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex): cur = connection.cursor() try: - # Start transaction a single transcation for all batches cur.execute("BEGIN TRANSACTION") for i in range(0, len(chunks), batch_size): batch_chunks = chunks[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] - # Prepare metadata inserts + + # Insert metadata metadata_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) for chunk in batch_chunks - if isinstance(chunk.content, str) ] - # Insert metadata (ON CONFLICT to avoid duplicates) cur.executemany( f""" INSERT INTO {self.metadata_table} (id, chunk) @@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex): """, metadata_data, ) - # Prepare embeddings inserts + + # Insert vector embeddings embedding_data = [ ( - generate_chunk_id(chunk.metadata["document_id"], chunk.content), - serialize_vector(emb.tolist()), + ( + generate_chunk_id(chunk.metadata["document_id"], chunk.content), + serialize_vector(emb.tolist()), + ) ) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) - if isinstance(chunk.content, str) ] - # Insert embeddings in batch - cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) + cur.executemany( + f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", + embedding_data, + ) + + # Insert FTS content + fts_data = [ + (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content) + for chunk in batch_chunks + ] + # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) + cur.executemany( + f"DELETE FROM {self.fts_table} WHERE id = ?;", + [(row[0],) for row in fts_data], + ) + + # INSERT new entries + cur.executemany( + f"INSERT INTO 
{self.fts_table} (id, content) VALUES (?, ?);", + fts_data, + ) + connection.commit() except sqlite3.Error as e: - connection.rollback() # Rollback on failure + connection.rollback() logger.error(f"Error inserting into {self.vector_table}: {e}") raise @@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex): cur.close() connection.close() - # Process all batches in a single thread + # Run batch insertion in a background thread await asyncio.to_thread(_execute_all_batch_inserts) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector( + self, + embedding: NDArray, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: """ - Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query - against the virtual table. The SQL joins the metadata table to recover the chunk JSON. + Performs vector-based search using a virtual table for vector similarity. """ - emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) - emb_blob = serialize_vector(emb_list) def _execute_query(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() - try: + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) query_sql = f""" SELECT m.id, m.chunk, v.distance FROM {self.vector_table} AS v @@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex): connection.close() rows = await asyncio.to_thread(_execute_query) - chunks, scores = [], [] - for _id, chunk_json, distance in rows: + for row in rows: + _id, chunk_json, distance = row + score = 1.0 / distance if distance != 0 else float("inf") + if score < score_threshold: + continue + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + scores.append(score) + return QueryChunksResponse(chunks=chunks, scores=scores) + + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search. + """ + if query_string is None: + raise ValueError("query_string is required for keyword search.") + + def _execute_query(): + connection = _create_sqlite_connection(self.db_path) + cur = connection.cursor() + try: + query_sql = f""" + SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score + FROM {self.fts_table} AS f + JOIN {self.metadata_table} AS m ON m.id = f.id + WHERE f.content MATCH ? + ORDER BY score ASC + LIMIT ?; + """ + cur.execute(query_sql, (query_string, k)) + return cur.fetchall() + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_execute_query) + chunks, scores = [], [] + for row in rows: + _id, chunk_json, score = row + # BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative). + # This design is intentional to simplify sorting by ascending score. 
+ # Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html + if score > -score_threshold: + continue try: chunk = Chunk.model_validate_json(chunk_json) except Exception as e: logger.error(f"Error parsing chunk JSON for id {_id}: {e}") continue chunks.append(chunk) - # Mimic the Faiss scoring: score = 1/distance (avoid division by zero) - score = 1.0 / distance if distance != 0 else float("inf") scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index c209da092..e0a04be48 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="sambanova", + pip_packages=["litellm"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b9194810e..277914df2 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]: adapter=AdapterSpec( adapter_type="model-context-protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", pip_packages=["mcp"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", ), ), ] diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 9a1ec7ee0..c3c25edd3 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): if prompt_logprobs is not None: logging.warning("prompt_logprobs is not supported by the OpenAI API. 
Ignoring.") + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=(await self.model_store.get_model(model)).provider_resource_id, + model=model_id, prompt=prompt, best_of=best_of, echo=echo, @@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): top_p: float | None = None, user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] params = await prepare_openai_completion_params( - model=(await self.model_store.get_model(model)).provider_resource_id, + model=model_id, messages=messages, frequency_penalty=frequency_penalty, function_call=function_call, diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index 8530594b6..99abddf51 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -4,8 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pathlib import Path -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from llama_stack.schema_utils import json_schema_type @@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel): default="fake", description="The API token", ) - tls_verify: bool = Field( + tls_verify: bool | str = Field( default=True, - description="Whether to verify TLS certificates", + description="Whether to verify TLS certificates. 
Can be a boolean or a path to a CA certificate file.", ) + @field_validator("tls_verify") + @classmethod + def validate_tls_verify(cls, v): + if isinstance(v, str): + # Check if it's a boolean string + if v.lower() in ("true", "false"): + return v.lower() == "true" + # Otherwise, treat it as a cert path + cert_path = Path(v).expanduser().resolve() + if not cert_path.exists(): + raise ValueError(f"TLS certificate file does not exist: {v}") + if not cert_path.is_file(): + raise ValueError(f"TLS certificate path is not a file: {v}") + return v + return v + @classmethod def sample_run_config( cls, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index d00218dd5..fe2d8bec1 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): return AsyncOpenAI( base_url=self.config.url, api_key=self.config.api_token, - http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), + http_client=httpx.AsyncClient(verify=self.config.tls_verify), ) async def completion( diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index 409818cb3..d839ffd6f 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): Parameters: training_config: TrainingConfig - Configuration for training - model: str - Model identifier + model: str - NeMo Customizer configuration name algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm job_uuid: str - Unique identifier for the job, ignored atm @@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. """ - # Map model to nvidia model name - # See `_MODEL_ENTRIES` for supported models - nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters unsupported_method_params = [] @@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Prepare base job configuration job_config = { - "config": nvidia_model, + "config": model, "dataset": { "name": training_config["data_config"]["dataset_id"], "namespace": self.config.dataset_namespace, diff --git a/llama_stack/providers/remote/safety/sambanova/__init__.py b/llama_stack/providers/remote/safety/sambanova/__init__.py new file mode 100644 index 000000000..bb9d15374 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +from typing import Any + +from .config import SambaNovaSafetyConfig + + +async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any: + from .sambanova import SambaNovaSafetyAdapter + + impl = SambaNovaSafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/sambanova/config.py b/llama_stack/providers/remote/safety/sambanova/config.py new file mode 100644 index 000000000..383cea244 --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/config.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.schema_utils import json_schema_type + + +class SambaNovaProviderDataValidator(BaseModel): + sambanova_api_key: str | None = Field( + default=None, + description="Sambanova Cloud API key", + ) + + +@json_schema_type +class SambaNovaSafetyConfig(BaseModel): + url: str = Field( + default="https://api.sambanova.ai/v1", + description="The URL for the SambaNova AI server", + ) + api_key: SecretStr | None = Field( + default=None, + description="The SambaNova cloud API Key", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: + return { + "url": "https://api.sambanova.ai/v1", + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py new file mode 100644 index 000000000..84c8267ae --- /dev/null +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging +from typing import Any + +import litellm +import requests + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new + +from .config import SambaNovaSafetyConfig + +logger = logging.getLogger(__name__) + +CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
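# The adapter below resolves the API key either from its config or from per-request provider
# data. A hedged sketch of what a caller would send; the header name is the one documented in
# `_get_api_key` below, while the key value is a placeholder.
import json

provider_data_header = {
    "X-LlamaStack-Provider-Data": json.dumps({"sambanova_api_key": "your-sambanova-api-key"}),
}
# Attach `provider_data_header` to requests against the stack so `_get_api_key()` can fall
# back to `provider_data.sambanova_api_key` when no key is set in the provider config.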
+ + +class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): + def __init__(self, config: SambaNovaSafetyConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + def _get_api_key(self) -> str: + config_api_key = self.config.api_key if self.config.api_key else None + if config_api_key: + return config_api_key.get_secret_value() + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.sambanova_api_key: + raise ValueError( + 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' + ) + return provider_data.sambanova_api_key + + async def register_shield(self, shield: Shield) -> None: + list_models_url = self.config.url + "/models" + try: + response = requests.get(list_models_url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Request to {list_models_url} failed") from e + available_models = [model.get("id") for model in response.json().get("data", {})] + if ( + len(available_models) == 0 + or "guard" not in shield.provider_resource_id.lower() + or shield.provider_resource_id.split("sambanova/")[-1] not in available_models + ): + raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") + + async def run_shield( + self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + + shield_params = shield.params + logger.debug(f"run_shield::{shield_params}::messages={messages}") + content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = litellm.completion( + model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() + ) + shield_message = response.choices[0].message.content + + if "unsafe" in shield_message.lower(): + user_message = CANNED_RESPONSE_TEXT + violation_type = shield_message.split("\n")[-1] + metadata = {"violation_type": violation_type} + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + + return RunShieldResponse() diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 18bec463f..7e82cb6d4 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BingSearchToolConfig -class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BingSearchToolConfig): self.config = 
config self.url = "https://api.bing.microsoft.com/v7.0/search" @@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 355cb98b6..b96b9e59c 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -11,30 +11,30 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.models.llama.datatypes import BuiltinTool -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BraveSearchToolConfig -class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BraveSearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py index fb1f558e5..051a880a7 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -4,18 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
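# The MCP changes that follow route per-endpoint headers through request provider data:
# `mcp_headers` maps an MCP endpoint to the headers to send it. A hedged sketch of that
# payload; the endpoint URL and bearer token are placeholders, and the header name is the
# provider-data header used elsewhere in this PR.
import json

mcp_provider_data = {
    "mcp_headers": {
        "https://my-mcp-server.example.com/sse": {"Authorization": "Bearer <mcp-token>"},
    },
}
request_headers = {"X-LlamaStack-Provider-Data": json.dumps(mcp_provider_data)}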
-from pydantic import BaseModel - -from .config import ModelContextProtocolConfig +from .config import MCPProviderConfig -class ModelContextProtocolToolProviderDataValidator(BaseModel): - api_key: str - - -async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): +async def get_adapter_impl(config: MCPProviderConfig, _deps): from .model_context_protocol import ModelContextProtocolToolRuntimeImpl - impl = ModelContextProtocolToolRuntimeImpl(config) + impl = ModelContextProtocolToolRuntimeImpl(config, _deps) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index d509074fc..b8c5e77fd 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -9,7 +9,12 @@ from typing import Any from pydantic import BaseModel -class ModelContextProtocolConfig(BaseModel): +class MCPProviderDataValidator(BaseModel): + # mcp_endpoint => dict of headers to send + mcp_headers: dict[str, dict[str, str]] | None = None + + +class MCPProviderConfig(BaseModel): @classmethod def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 142730e89..a9b252dfe 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -7,61 +7,45 @@ from typing import Any from urllib.parse import urlparse -from mcp import ClientSession -from mcp.client.sse import sse_client - from llama_stack.apis.common.content_types import URL +from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, - ToolDef, + ToolGroup, ToolInvocationResult, - ToolParameter, ToolRuntime, ) -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools -from .config import ModelContextProtocolConfig +from .config import MCPProviderConfig + +logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: ModelContextProtocolConfig): +class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): + def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config async def initialize(self): pass + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + pass + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + return + async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: + # this endpoint should be retrieved by getting the tool group right? 
if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - - tools = [] - async with sse_client(mcp_endpoint.uri) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": mcp_endpoint.uri, - }, - ) - ) - return ListToolDefsResponse(data=tools) + headers = await self.get_headers_from_request(mcp_endpoint.uri) + return await list_mcp_tools(mcp_endpoint.uri, headers) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) @@ -71,12 +55,19 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - async with sse_client(endpoint) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - result = await session.call_tool(tool.identifier, kwargs) + headers = await self.get_headers_from_request(endpoint) + return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) - return ToolInvocationResult( - content="\n".join([result.model_dump_json() for result in result.content]), - error_code=1 if result.isError else 0, - ) + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: + def canonicalize_uri(uri: str) -> str: + return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" + + headers = {} + + provider_data = self.get_request_provider_data() + if provider_data and provider_data.mcp_headers: + for uri, values in provider_data.mcp_headers.items(): + if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): + continue + headers.update(values) + return headers diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 9d6fcd951..1fe91fd7f 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -12,29 +12,29 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import TavilySearchToolConfig -class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: TavilySearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def 
unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index a3724e4b4..6e1d0f61d 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import WolframAlphaToolConfig -class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: WolframAlphaToolConfig): self.config = config self.url = "https://api.wolframalpha.com/v2/query" @@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index a919963ab..a59a38573 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex): async def delete(self): await maybe_await(self.client.delete_collection(self.collection.name)) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Chroma") + class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index c98417b56..6628292db 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex): logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: search_res = await asyncio.to_thread( self.client.search, collection_name=self.collection_name, @@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex): scores = [res["distance"] for res in search_res[0]] return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Milvus") + class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def 
__init__( diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 94546c6cf..ea918c552 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: execute_values(cur, query, values, template="(%s, %s, %s::vector)") - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" @@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in PGVector") + async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 514a6c70d..ff0690083 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( collection_name=self.collection_name, @@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Qdrant") + async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 308d2eb3d..e6fe8ccd3 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) results = collection.query.near_vector( @@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex): collection = self.client.collections.get(self.collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Keyword search is not supported in Weaviate") + class WeaviateVectorIOAdapter( 
VectorIO, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py new file mode 100644 index 000000000..7b6bc2e3d --- /dev/null +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.inference import ( + ListOpenAIChatCompletionResponse, + OpenAIChatCompletion, + OpenAICompletionWithInputMessages, + OpenAIMessageParam, + Order, +) +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl + + +class InferenceStore: + def __init__(self, sql_store_config: SqlStoreConfig): + if not sql_store_config: + sql_store_config = SqliteSqlStoreConfig( + db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + ) + self.sql_store_config = sql_store_config + self.sql_store = None + + async def initialize(self): + """Create the necessary tables if they don't exist.""" + self.sql_store = sqlstore_impl(self.sql_store_config) + await self.sql_store.create_table( + "chat_completions", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "created": ColumnType.INTEGER, + "model": ColumnType.STRING, + "choices": ColumnType.JSON, + "input_messages": ColumnType.JSON, + }, + ) + + async def store_chat_completion( + self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam] + ) -> None: + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + data = chat_completion.model_dump() + + await self.sql_store.insert( + "chat_completions", + { + "id": data["id"], + "created": data["created"], + "model": data["model"], + "choices": data["choices"], + "input_messages": [message.model_dump() for message in input_messages], + }, + ) + + async def list_chat_completions( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIChatCompletionResponse: + """ + List chat completions from the database. + + :param after: The ID of the last chat completion to return. + :param limit: The maximum number of chat completions to return. + :param model: The model to filter by. + :param order: The order to sort the chat completions by. 
+ """ + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + # TODO: support after + if after: + raise NotImplementedError("After is not supported for SQLite") + if not order: + order = Order.desc + + rows = await self.sql_store.fetch_all( + "chat_completions", + where={"model": model} if model else None, + order_by=[("created", order.value)], + limit=limit, + ) + + data = [ + OpenAICompletionWithInputMessages( + id=row["id"], + created=row["created"], + model=row["model"], + choices=row["choices"], + input_messages=row["input_messages"], + ) + for row in rows + ] + return ListOpenAIChatCompletionResponse( + data=data, + # TODO: implement has_more + has_more=False, + first_id=data[0].id if data else "", + last_id=data[-1].id if data else "", + ) + + async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + if not self.sql_store: + raise ValueError("Inference store is not initialized") + + row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id}) + if not row: + raise ValueError(f"Chat completion with id {completion_id} not found") from None + return OpenAICompletionWithInputMessages( + id=row["id"], + created=row["created"], + model=row["model"], + choices=row["choices"], + input_messages=row["input_messages"], + ) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index cc0000528..049f06fdb 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin: outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], ): id = f"chatcmpl-{uuid.uuid4()}" - for outstanding_response in outstanding_responses: + for i, outstanding_response in enumerate(outstanding_responses): response = await outstanding_response - i = 0 async for chunk in response: event = chunk.event finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) @@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin: model=model, object="chat.completion.chunk", ) - i = i + 1 async def _process_non_stream_response( self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] diff --git a/llama_stack/providers/utils/inference/stream_utils.py b/llama_stack/providers/utils/inference/stream_utils.py new file mode 100644 index 000000000..a2edbb9c8 --- /dev/null +++ b/llama_stack/providers/utils/inference/stream_utils.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import Any + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIChoiceLogprobs, + OpenAIMessageParam, +) +from llama_stack.providers.utils.inference.inference_store import InferenceStore + + +async def stream_and_store_openai_completion( + provider_stream: AsyncIterator[OpenAIChatCompletionChunk], + model: str, + store: InferenceStore, + input_messages: list[OpenAIMessageParam], +) -> AsyncIterator[OpenAIChatCompletionChunk]: + """ + Wraps a provider's stream, yields chunks, and stores the full completion at the end. + """ + id = None + created = None + choices_data: dict[int, dict[str, Any]] = {} + + try: + async for chunk in provider_stream: + if id is None and chunk.id: + id = chunk.id + if created is None and chunk.created: + created = chunk.created + + if chunk.choices: + for choice_delta in chunk.choices: + idx = choice_delta.index + if idx not in choices_data: + choices_data[idx] = { + "content_parts": [], + "tool_calls_builder": {}, + "finish_reason": None, + "logprobs_content_parts": [], + } + current_choice_data = choices_data[idx] + + if choice_delta.delta: + delta = choice_delta.delta + if delta.content: + current_choice_data["content_parts"].append(delta.content) + if delta.tool_calls: + for tool_call_delta in delta.tool_calls: + tc_idx = tool_call_delta.index + if tc_idx not in current_choice_data["tool_calls_builder"]: + # Initialize with correct structure for _ToolCallBuilderData + current_choice_data["tool_calls_builder"][tc_idx] = { + "id": None, + "type": "function", + "function_name_parts": [], + "function_arguments_parts": [], + } + builder = current_choice_data["tool_calls_builder"][tc_idx] + if tool_call_delta.id: + builder["id"] = tool_call_delta.id + if tool_call_delta.type: + builder["type"] = tool_call_delta.type + if tool_call_delta.function: + if tool_call_delta.function.name: + builder["function_name_parts"].append(tool_call_delta.function.name) + if tool_call_delta.function.arguments: + builder["function_arguments_parts"].append(tool_call_delta.function.arguments) + if choice_delta.finish_reason: + current_choice_data["finish_reason"] = choice_delta.finish_reason + if choice_delta.logprobs and choice_delta.logprobs.content: + # Ensure that we are extending with the correct type + current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + yield chunk + finally: + if id: + assembled_choices: list[OpenAIChoice] = [] + for choice_idx, choice_data in choices_data.items(): + content_str = "".join(choice_data["content_parts"]) + assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] + if choice_data["tool_calls_builder"]: + for tc_build_data in choice_data["tool_calls_builder"].values(): + if tc_build_data["id"]: + func_name = "".join(tc_build_data["function_name_parts"]) + func_args = "".join(tc_build_data["function_arguments_parts"]) + assembled_tool_calls.append( + OpenAIChatCompletionToolCall( + id=tc_build_data["id"], + type=tc_build_data["type"], # No or "function" needed, already set + function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args), + ) + ) + message = OpenAIAssistantMessageParam( + role="assistant", + content=content_str if content_str else None, + tool_calls=assembled_tool_calls if assembled_tool_calls else None, + ) + 
logprobs_content = choice_data["logprobs_content_parts"] + final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + + assembled_choices.append( + OpenAIChoice( + finish_reason=choice_data["finish_reason"], + index=choice_idx, + message=message, + logprobs=final_logprobs, + ) + ) + + final_response = OpenAIChatCompletion( + id=id, + choices=assembled_choices, + created=created or int(datetime.now(timezone.utc).timestamp()), + model=model, + object="chat.completion", + ) + await store.store_chat_completion(final_response, input_messages) diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index e0e9d0679..3655c7049 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -177,7 +177,11 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + raise NotImplementedError() + + @abstractmethod + async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() @abstractmethod @@ -210,9 +214,12 @@ class VectorDBWithIndex: if params is None: params = {} k = params.get("max_chunks", 3) + mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) - - query_str = interleaved_content_as_str(query) - embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str]) - query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) - return await self.index.query(query_vector, k, score_threshold) + query_string = interleaved_content_as_str(query) + if mode == "keyword": + return await self.index.query_keyword(query_string, k, score_threshold) + else: + embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) + return await self.index.query_vector(query_vector, k, score_threshold) diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py new file mode 100644 index 000000000..15354e3e2 --- /dev/null +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
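# A minimal sketch of driving the store defined below; the SQLite path is a scratch
# placeholder, and `response_object` / `input_items` are assumed to be objects produced by
# the Responses implementation rather than values from this PR.
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def save_and_list(response_object, input_items):
    store = ResponsesStore(SqliteSqlStoreConfig(db_path="/tmp/responses_store.db"))
    await store.initialize()
    await store.store_response_object(response_object, input=input_items)
    return await store.list_responses(limit=10)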
+from llama_stack.apis.agents import ( + Order, +) +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, + OpenAIResponseInput, + OpenAIResponseObject, + OpenAIResponseObjectWithInput, +) +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl + + +class ResponsesStore: + def __init__(self, sql_store_config: SqlStoreConfig): + if not sql_store_config: + sql_store_config = SqliteSqlStoreConfig( + db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + ) + self.sql_store = sqlstore_impl(sql_store_config) + + async def initialize(self): + """Create the necessary tables if they don't exist.""" + await self.sql_store.create_table( + "openai_responses", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "created_at": ColumnType.INTEGER, + "response_object": ColumnType.JSON, + "model": ColumnType.STRING, + }, + ) + + async def store_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + data = response_object.model_dump() + data["input"] = [input_item.model_dump() for input_item in input] + + await self.sql_store.insert( + "openai_responses", + { + "id": data["id"], + "created_at": data["created_at"], + "model": data["model"], + "response_object": data, + }, + ) + + async def list_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """ + List responses from the database. + + :param after: The ID of the last response to return. + :param limit: The maximum number of responses to return. + :param model: The model to filter by. + :param order: The order to sort the responses by. + """ + # TODO: support after + if after: + raise NotImplementedError("After is not supported for SQLite") + if not order: + order = Order.desc + + rows = await self.sql_store.fetch_all( + "openai_responses", + where={"model": model} if model else None, + order_by=[("created_at", order.value)], + limit=limit, + ) + + data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows] + return ListOpenAIResponseObject( + data=data, + # TODO: implement has_more + has_more=False, + first_id=data[0].id if data else "", + last_id=data[-1].id if data else "", + ) + + async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) + if not row: + raise ValueError(f"Response with id {response_id} not found") from None + return OpenAIResponseObjectWithInput(**row["response_object"]) + + async def list_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """ + List input items for a given response. + + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. 
+ """ + # TODO: support after/before pagination + if after or before: + raise NotImplementedError("After/before pagination is not supported yet") + if include: + raise NotImplementedError("Include is not supported yet") + + response_with_input = await self.get_response_object(response_id) + input_items = response_with_input.input + + if order == Order.desc: + input_items = list(reversed(input_items)) + + if limit is not None and len(input_items) > limit: + input_items = input_items[:limit] + + return ListOpenAIResponseInputItem(data=input_items) diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py new file mode 100644 index 000000000..ace40e4c4 --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import Mapping +from enum import Enum +from typing import Any, Literal, Protocol + +from pydantic import BaseModel + + +class ColumnType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + TEXT = "TEXT" + FLOAT = "FLOAT" + BOOLEAN = "BOOLEAN" + JSON = "JSON" + DATETIME = "DATETIME" + + +class ColumnDefinition(BaseModel): + type: ColumnType + primary_key: bool = False + nullable: bool = True + default: Any = None + + +class SqlStore(Protocol): + """ + A protocol for a SQL store. + """ + + async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: + """ + Create a table. + """ + pass + + async def insert(self, table: str, data: Mapping[str, Any]) -> None: + """ + Insert a row into a table. + """ + pass + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> list[dict[str, Any]]: + """ + Fetch all rows from a table. + """ + pass + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + """ + Fetch one row from a table. + """ + pass + + async def update( + self, + table: str, + data: Mapping[str, Any], + where: Mapping[str, Any], + ) -> None: + """ + Update a row in a table. + """ + pass + + async def delete( + self, + table: str, + where: Mapping[str, Any], + ) -> None: + """ + Delete a row from a table. + """ + pass diff --git a/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py b/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py new file mode 100644 index 000000000..0ef5f0fa1 --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/sqlite/sqlite.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from collections.abc import Mapping +from typing import Any, Literal + +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Float, + Integer, + MetaData, + String, + Table, + Text, + select, +) +from sqlalchemy.ext.asyncio import create_async_engine + +from ..api import ColumnDefinition, ColumnType, SqlStore +from ..sqlstore import SqliteSqlStoreConfig + +TYPE_MAPPING: dict[ColumnType, Any] = { + ColumnType.INTEGER: Integer, + ColumnType.STRING: String, + ColumnType.FLOAT: Float, + ColumnType.BOOLEAN: Boolean, + ColumnType.DATETIME: DateTime, + ColumnType.TEXT: Text, + ColumnType.JSON: JSON, +} + + +class SqliteSqlStoreImpl(SqlStore): + def __init__(self, config: SqliteSqlStoreConfig): + self.engine = create_async_engine(config.engine_str) + self.metadata = MetaData() + + async def create_table( + self, + table: str, + schema: Mapping[str, ColumnType | ColumnDefinition], + ) -> None: + if not schema: + raise ValueError(f"No columns defined for table '{table}'.") + + sqlalchemy_columns: list[Column] = [] + + for col_name, col_props in schema.items(): + col_type = None + is_primary_key = False + is_nullable = True # Default to nullable + + if isinstance(col_props, ColumnType): + col_type = col_props + elif isinstance(col_props, ColumnDefinition): + col_type = col_props.type + is_primary_key = col_props.primary_key + is_nullable = col_props.nullable + + sqlalchemy_type = TYPE_MAPPING.get(col_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") + + sqlalchemy_columns.append( + Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) + ) + + # Check if table already exists in metadata, otherwise define it + if table not in self.metadata.tables: + sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) + else: + sqlalchemy_table = self.metadata.tables[table] + + # Create the table in the database if it doesn't exist + # checkfirst=True ensures it doesn't try to recreate if it's already there + async with self.engine.begin() as conn: + await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) + + async def insert(self, table: str, data: Mapping[str, Any]) -> None: + async with self.engine.begin() as conn: + await conn.execute(self.metadata.tables[table].insert(), data) + await conn.commit() + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> list[dict[str, Any]]: + async with self.engine.begin() as conn: + query = select(self.metadata.tables[table]) + if where: + for key, value in where.items(): + query = query.where(self.metadata.tables[table].c[key] == value) + if limit: + query = query.limit(limit) + if order_by: + if not isinstance(order_by, list): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + for order in order_by: + if not isinstance(order, tuple): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + name, order_type = order + if order_type == "asc": + query = query.order_by(self.metadata.tables[table].c[name].asc()) + elif order_type == "desc": + query = query.order_by(self.metadata.tables[table].c[name].desc()) + else: + raise ValueError(f"Invalid order '{order_type}' for column '{name}'") + result = await conn.execute(query) + if result.rowcount == 0: + return [] + return 
[dict(row._mapping) for row in result] + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + rows = await self.fetch_all(table, where, limit=1, order_by=order_by) + if not rows: + return None + return rows[0] + + async def update( + self, + table: str, + data: Mapping[str, Any], + where: Mapping[str, Any], + ) -> None: + if not where: + raise ValueError("where is required for update") + + async with self.engine.begin() as conn: + stmt = self.metadata.tables[table].update() + for key, value in where.items(): + stmt = stmt.where(self.metadata.tables[table].c[key] == value) + await conn.execute(stmt, data) + await conn.commit() + + async def delete(self, table: str, where: Mapping[str, Any]) -> None: + if not where: + raise ValueError("where is required for delete") + + async with self.engine.begin() as conn: + stmt = self.metadata.tables[table].delete() + for key, value in where.items(): + stmt = stmt.where(self.metadata.tables[table].c[key] == value) + await conn.execute(stmt) + await conn.commit() diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py new file mode 100644 index 000000000..99f64805f --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from enum import Enum +from pathlib import Path +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from .api import SqlStore + + +class SqlStoreType(Enum): + sqlite = "sqlite" + postgres = "postgres" + + +class SqliteSqlStoreConfig(BaseModel): + type: Literal["sqlite"] = SqlStoreType.sqlite.value + db_path: str = Field( + default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + description="Database path, e.g. 
~/.llama/distributions/ollama/sqlstore.db", + ) + + @property + def engine_str(self) -> str: + return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix() + + @classmethod + def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"): + return cls( + type="sqlite", + db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name, + ) + + # TODO: move this when we have a better way to specify dependencies with internal APIs + @property + def pip_packages(self) -> list[str]: + return ["sqlalchemy[asyncio]"] + + +class PostgresSqlStoreConfig(BaseModel): + type: Literal["postgres"] = SqlStoreType.postgres.value + + @property + def pip_packages(self) -> list[str]: + raise NotImplementedError("Postgres is not implemented yet") + + +SqlStoreConfig = Annotated[ + SqliteSqlStoreConfig | PostgresSqlStoreConfig, + Field(discriminator="type", default=SqlStoreType.sqlite.value), +] + + +def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: + if config.type == SqlStoreType.sqlite.value: + from .sqlite.sqlite import SqliteSqlStoreImpl + + impl = SqliteSqlStoreImpl(config) + elif config.type == SqlStoreType.postgres.value: + raise NotImplementedError("Postgres is not implemented yet") + else: + raise ValueError(f"Unknown sqlstore type {config.type}") + + return impl diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 0f4fdd0d8..4edfa6516 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core") INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 +ROOT_SPAN_MARKERS = ["__root__", "__root_span__"] + def trace_id_to_str(trace_id: int) -> str: """Convenience trace ID formatting method @@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont trace_id = generate_trace_id() context = TraceContext(BACKGROUND_LOGGER, trace_id) - context.push_span(name, {"__root__": True, **(attributes or {})}) + attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {}) + context.push_span(name, attributes) CURRENT_TRACE_CONTEXT.set(context) return context diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py new file mode 100644 index 000000000..f024693a0 --- /dev/null +++ b/llama_stack/providers/utils/tools/mcp.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
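# A hedged sketch of calling the helpers defined below; the MCP endpoint, auth header and
# tool name/arguments are placeholders, not values from this PR.
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools


async def call_mcp_example():
    endpoint = "https://my-mcp-server.example.com/sse"
    headers = {"Authorization": "Bearer <mcp-token>"}
    tool_defs = await list_mcp_tools(endpoint, headers)  # ListToolDefsResponse
    result = await invoke_mcp_tool(endpoint, headers, "search", {"query": "llama stack"})
    return tool_defs, result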
+ +from contextlib import asynccontextmanager +from typing import Any + +try: + # for python < 3.11 + import exceptiongroup + + BaseExceptionGroup = exceptiongroup.BaseExceptionGroup +except ImportError: + pass + +import httpx +from mcp import ClientSession +from mcp import types as mcp_types +from mcp.client.sse import sse_client + +from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem +from llama_stack.apis.tools import ( + ListToolDefsResponse, + ToolDef, + ToolInvocationResult, + ToolParameter, +) +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="tools") + + +@asynccontextmanager +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): + try: + async with sse_client(endpoint, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + except BaseException as e: + if isinstance(e, BaseExceptionGroup): + for exc in e.exceptions: + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: + raise AuthenticationRequiredError(e) from e + + raise + + +async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: + tools = [] + async with sse_client_wrapper(endpoint, headers) as session: + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": endpoint, + }, + ) + ) + return ListToolDefsResponse(data=tools) + + +async def invoke_mcp_tool( + endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] +) -> ToolInvocationResult: + async with sse_client_wrapper(endpoint, headers) as session: + result = await session.call_tool(tool_name, kwargs) + + content: list[InterleavedContentItem] = [] + for item in result.content: + if isinstance(item, mcp_types.TextContent): + content.append(TextContentItem(text=item.text)) + elif isinstance(item, mcp_types.ImageContent): + content.append(ImageContentItem(image=item.data)) + elif isinstance(item, mcp_types.EmbeddedResource): + logger.warning(f"EmbeddedResource is not supported: {item}") + else: + raise ValueError(f"Unknown content type: {type(item)}") + return ToolInvocationResult( + content=content, + error_code=1 if result.isError else 0, + ) diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index 46d5b9c69..09fbf307d 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -29,3 +29,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index 30599a6c0..a58068a60 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -35,6 +35,9 @@ providers: type: sqlite namespace: null db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -96,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/inference_store.db models: - metadata: {} model_id: meta.llama3-1-8b-instruct-v1:0 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index 0498da1cd..95b0302f2 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -29,3 +29,5 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 0731b1df9..c080536b7 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -99,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/inference_store.db models: - metadata: {} model_id: llama3.1-8b diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index a4c5893c4..6fe96c603 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index d9ee5b3cf..368187d3a 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -99,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index f5beb6c2f..d37215f35 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/run-with-safety.yaml 
b/llama_stack/templates/dell/run-with-safety.yaml index 24c515112..5c6072245 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -99,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index fdece894f..ffaa0bf2f 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -37,6 +37,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -95,6 +98,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json deleted file mode 100644 index fb4ab9fda..000000000 --- a/llama_stack/templates/dependencies.json +++ /dev/null @@ -1,843 +0,0 @@ -{ - "bedrock": [ - "aiosqlite", - "autoevals", - "blobfile", - "boto3", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn" - ], - "cerebras": [ - "aiosqlite", - "autoevals", - "blobfile", - "cerebras_cloud_sdk", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "ci-tests": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - 
"sqlite-vec", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "dell": [ - "aiohttp", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "huggingface_hub", - "langdetect", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "fireworks": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "groq": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "litellm", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn" - ], - "hf-endpoint": [ - "aiohttp", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "huggingface_hub", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn" - ], - "hf-serverless": [ - "aiohttp", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "huggingface_hub", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "llama_api": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fastapi", - "fire", - "httpx", - "langdetect", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", 
- "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlite-vec", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "meta-reference-gpu": [ - "accelerate", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fairscale", - "faiss-cpu", - "fastapi", - "fbgemm-gpu-genai==1.1.2", - "fire", - "httpx", - "langdetect", - "lm-format-enforcer", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentence-transformers", - "sentencepiece", - "torch", - "torchao==0.8.0", - "torchvision", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "zmq" - ], - "nvidia": [ - "aiohttp", - "aiosqlite", - "blobfile", - "chardet", - "datasets", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn" - ], - "ollama": [ - "aiohttp", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "ollama", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "peft", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "tqdm", - "transformers", - "tree_sitter", - "trl", - "uvicorn" - ], - "open-benchmark": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fastapi", - "fire", - "httpx", - "langdetect", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlite-vec", - "together", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn" - ], - "passthrough": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "remote-vllm": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - 
"opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "sambanova": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "starter": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "langdetect", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlite-vec", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "tgi": [ - "aiohttp", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "huggingface_hub", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "together": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "together", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "verification": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fastapi", - "fire", - "httpx", - "langdetect", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlite-vec", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url 
https://download.pytorch.org/whl/cpu" - ], - "vllm-gpu": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "vllm", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ], - "watsonx": [ - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "datasets", - "emoji", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "ibm_watson_machine_learning", - "langdetect", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" - ] -} diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index 7c74157ee..f162d9b43 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 0ab07613e..41500f6f6 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -111,6 +114,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 81c293a46..b1fa03306 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -106,6 +109,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db models: - metadata: {} 
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/groq/build.yaml b/llama_stack/templates/groq/build.yaml index 800c3e3ae..92b46ce66 100644 --- a/llama_stack/templates/groq/build.yaml +++ b/llama_stack/templates/groq/build.yaml @@ -26,3 +26,5 @@ distribution_spec: - remote::tavily-search - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 79c350c73..db7ebffee 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -99,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/inference_store.db models: - metadata: {} model_id: groq/llama3-8b-8192 diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 2a40c3909..4d09cc33e 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -29,3 +29,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index 82bcaa3cf..15cf2a47f 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -107,6 +110,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index ec7c55032..428edf9a2 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,6 +105,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git 
a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index f77f8773b..d06c628ac 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 320976e2c..ab461c6c3 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -107,6 +110,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index 2b22b20c6..d238506fb 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,6 +105,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/llama_api/build.yaml b/llama_stack/templates/llama_api/build.yaml index f97ee4091..d0dc08923 100644 --- a/llama_stack/templates/llama_api/build.yaml +++ b/llama_stack/templates/llama_api/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index a879482d7..a7f2b0769 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -111,6 +114,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/inference_store.db models: - metadata: {} model_id: 
Llama-3.3-70B-Instruct diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index a9d03490b..e0ac87e47 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -29,3 +29,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 180d44e0f..2b751a514 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -56,6 +56,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -117,6 +120,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index d879667e0..a24c5fec5 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -107,6 +110,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index a05cf97ad..e1e6fb3d8 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -24,3 +24,6 @@ distribution_spec: tool_runtime: - inline::rag-runtime image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 0c088d8f8..eeccc006a 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -92,6 +95,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +inference_store: 
+ type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index eb272e706..cd36ec362 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -80,6 +83,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db models: - metadata: {} model_id: meta/llama3-8b-instruct diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 7d5363575..9d8ba3a1e 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -32,3 +32,6 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 74d0822ca..d63c5e366 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -112,6 +115,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 71229be97..d208cd7f0 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -110,6 +113,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index b14e96435..aa6d876fe 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -33,3 +33,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda 
+additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 30a27cbd8..0e5edf728 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -64,6 +64,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -125,6 +128,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml index f8d099070..7560f1032 100644 --- a/llama_stack/templates/passthrough/build.yaml +++ b/llama_stack/templates/passthrough/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml b/llama_stack/templates/passthrough/run-with-safety.yaml index a91b9fc92..bbf5d9a52 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -111,6 +114,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index d1dd3b885..146906d9b 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -106,6 +109,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index 4baaaf9c8..fcd4deeff 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - remote::model-context-protocol - 
remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 6931d4ba9..e83162a4f 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -115,6 +118,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 05671165d..4cdf88c6b 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -43,6 +43,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference @@ -108,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 81d90f420..b644dcfdc 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use SambaNova for running LLM inference + description: Use SambaNova for running LLM inference and safety providers: inference: - remote::sambanova @@ -10,7 +10,7 @@ distribution_spec: - remote::chromadb - remote::pgvector safety: - - inline::llama-guard + - remote::sambanova agents: - inline::meta-reference telemetry: @@ -22,3 +22,5 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md index 42d9efb66..1dc76fd3f 100644 --- a/llama_stack/templates/sambanova/doc_template.md +++ b/llama_stack/templates/sambanova/doc_template.md @@ -37,33 +37,44 @@ The following models are available by default: ### Prerequisite: API Keys -Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup). ## Running Llama Stack with SambaNova You can do this via Conda (build code) or Docker which has a pre-built image. 
-### Via Docker -This method allows you to get started quickly without having to build the distribution code. +### Via Docker ```bash LLAMA_STACK_PORT=8321 +llama stack build --template sambanova --image-type container docker run \ -it \ - --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - llamastack/distribution-{{ name }} \ + -v ~/.llama:/root/.llama \ + distribution-{{ name }} \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` + +### Via Venv + +```bash +llama stack build --template sambanova --image-type venv +llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + + ### Via Conda ```bash llama stack build --template sambanova --image-type conda -llama stack run ./run.yaml \ +llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 620d50307..8c2a933ab 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -38,10 +38,11 @@ providers: user: ${env.PGVECTOR_USER:} password: ${env.PGVECTOR_PASSWORD:} safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_id: sambanova + provider_type: remote::sambanova config: - excluded_categories: [] + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -50,6 +51,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -81,6 +85,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/inference_store.db models: - metadata: {} model_id: sambanova/Meta-Llama-3.1-8B-Instruct @@ -189,6 +196,9 @@ models: model_type: embedding shields: - shield_id: meta-llama/Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B +- shield_id: sambanova/Meta-Llama-Guard-3-8B + provider_shield_id: sambanova/Meta-Llama-Guard-3-8B vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 2f8a0b08a..54a49423d 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -34,7 +34,7 @@ def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::sambanova", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], + "safety": ["remote::sambanova"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], "tool_runtime": [ @@ -110,7 +110,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Use SambaNova for running LLM inference", + description="Use SambaNova for running LLM inference and safety", container_image=None, 
template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -122,7 +122,15 @@ def get_distribution_template() -> DistributionTemplate: "vector_io": vector_io_providers, }, default_models=default_models + [embedding_model], - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_shields=[ + ShieldInput( + shield_id="meta-llama/Llama-Guard-3-8B", provider_shield_id="sambanova/Meta-Llama-Guard-3-8B" + ), + ShieldInput( + shield_id="sambanova/Meta-Llama-Guard-3-8B", + provider_shield_id="sambanova/Meta-Llama-Guard-3-8B", + ), + ], default_tool_groups=default_tool_groups, ), }, diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 35bd0c713..652814ffd 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -35,3 +35,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 402695850..04425ed35 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -72,6 +72,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -133,6 +136,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index e4d28d904..ec5cd38ea 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -29,6 +29,7 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig def get_model_registry( @@ -117,6 +118,10 @@ class RunConfigSettings(BaseModel): __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), + inference_store=SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=f"~/.llama/distributions/{name}", + db_name="inference_store.db", + ), models=self.default_models or [], shields=self.default_shields or [], tool_groups=self.default_tool_groups or [], @@ -146,14 +151,20 @@ class DistributionTemplate(BaseModel): available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None def build_config(self) -> BuildConfig: + additional_pip_packages: list[str] = [] + for run_config in self.run_configs.values(): + run_config_ = run_config.run_config(self.name, self.providers, self.container_image) + if run_config_.inference_store: + additional_pip_packages.extend(run_config_.inference_store.pip_packages) + return BuildConfig( - name=self.name, distribution_spec=DistributionSpec( description=self.description, container_image=self.container_image, providers=self.providers, ), image_type="conda", # default to 
conda, can be overridden + additional_pip_packages=additional_pip_packages, ) def generate_markdown_docs(self) -> str: diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index d2ba1c3e9..652900c84 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -30,3 +30,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index 3255e9c0b..c797b93aa 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -102,6 +105,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 179087258..7e91d20bd 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -101,6 +104,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index b7338795c..4a556a66f 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -31,3 +31,6 @@ distribution_spec: - remote::model-context-protocol - remote::wolfram-alpha image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index fe8c8e397..190a0400b 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -111,6 +114,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: 
meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index b903fc659..ce9542130 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -106,6 +109,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/verification/build.yaml b/llama_stack/templates/verification/build.yaml index aae24c3ca..cb7ab4798 100644 --- a/llama_stack/templates/verification/build.yaml +++ b/llama_stack/templates/verification/build.yaml @@ -35,3 +35,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index 11af41da9..58b3c576c 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -74,6 +74,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -135,6 +138,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/inference_store.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index 53e257f22..5a9d003cb 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -30,3 +30,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 5d3482528..6937e2bac 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -45,6 +45,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -106,6 +109,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/watsonx/build.yaml 
b/llama_stack/templates/watsonx/build.yaml index 638b16029..87233fb26 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -28,3 +28,5 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol image_type: conda +additional_pip_packages: +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 8de6a2b6c..e7222fd57 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -42,6 +42,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference @@ -103,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/inference_store.db models: - metadata: {} model_id: meta-llama/llama-3-3-70b-instruct diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore new file mode 100644 index 000000000..1b8ac8894 --- /dev/null +++ b/llama_stack/ui/.prettierignore @@ -0,0 +1,3 @@ +# Ignore artifacts: +build +coverage diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/llama_stack/ui/.prettierrc @@ -0,0 +1 @@ +{} diff --git a/llama_stack/ui/README.md b/llama_stack/ui/README.md index e3e21bf0b..b6f803509 100644 --- a/llama_stack/ui/README.md +++ b/llama_stack/ui/README.md @@ -1,6 +1,5 @@ ## This is WIP. - We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. ## Getting Started @@ -8,7 +7,7 @@ We use shadcdn/ui [Shadcn UI](https://ui.shadcn.com/) for the UI components. First, install dependencies: ```bash -npm install next react react-dom +npm install ``` Then, run the development server: @@ -23,4 +22,4 @@ pnpm dev bun dev ``` -Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. +Open [http://localhost:8322](http://localhost:8322) with your browser to see the result. 
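The UI diffs that follow add log-viewer pages which talk to a running stack through `llama-stack-client`: the client is constructed from `NEXT_PUBLIC_LLAMA_STACK_BASE_URL`, the list page calls `client.chat.completions.list()`, and the detail page calls `client.chat.completions.retrieve(id)`. A minimal standalone sketch of that flow, assuming a stack on the default port 8321 when the env var is unset (the fallback URL is an assumption for illustration, not part of this change):

```typescript
// Sketch only: mirrors how the new UI pages query the stack.
// Assumption: the server listens on http://localhost:8321 when the env var is unset.
import LlamaStackClient from "llama-stack-client";

const client = new LlamaStackClient({
  baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL ?? "http://localhost:8321",
});

async function showRecentCompletions() {
  // List recent chat completions, as the table page does.
  const response = await client.chat.completions.list();
  const completions = Array.isArray(response) ? response : (response as any).data;

  // Fetch a single completion by id, as the detail page does.
  if (completions?.length) {
    const detail = await client.chat.completions.retrieve(completions[0].id);
    console.log(JSON.stringify(detail, null, 2));
  }
}

showRecentCompletions().catch(console.error);
```

Note that port 8322 in the README change above is the UI dev server itself; the stack endpoint the client talks to is configured separately via `NEXT_PUBLIC_LLAMA_STACK_BASE_URL`.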
diff --git a/llama_stack/ui/app/layout.tsx b/llama_stack/ui/app/layout.tsx index f029002dd..ed8a6cd5d 100644 --- a/llama_stack/ui/app/layout.tsx +++ b/llama_stack/ui/app/layout.tsx @@ -20,7 +20,7 @@ export const metadata: Metadata = { }; import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; -import { AppSidebar } from "@/components/app-sidebar"; +import { AppSidebar } from "@/components/layout/app-sidebar"; export default function Layout({ children }: { children: React.ReactNode }) { return ( diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx new file mode 100644 index 000000000..f7c2580da --- /dev/null +++ b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx @@ -0,0 +1,62 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { useParams } from "next/navigation"; +import LlamaStackClient from "llama-stack-client"; +import { ChatCompletion } from "@/lib/types"; +import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail"; + +export default function ChatCompletionDetailPage() { + const params = useParams(); + const id = params.id as string; + + const [completionDetail, setCompletionDetail] = + useState(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + if (!id) { + setError(new Error("Completion ID is missing.")); + setIsLoading(false); + return; + } + + const client = new LlamaStackClient({ + baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL, + }); + + const fetchCompletionDetail = async () => { + setIsLoading(true); + setError(null); + setCompletionDetail(null); + try { + const response = await client.chat.completions.retrieve(id); + setCompletionDetail(response as ChatCompletion); + } catch (err) { + console.error( + `Error fetching chat completion detail for ID ${id}:`, + err, + ); + setError( + err instanceof Error + ? err + : new Error("Failed to fetch completion detail"), + ); + } finally { + setIsLoading(false); + } + }; + + fetchCompletionDetail(); + }, [id]); + + return ( + + ); +} diff --git a/llama_stack/ui/app/logs/chat-completions/layout.tsx b/llama_stack/ui/app/logs/chat-completions/layout.tsx new file mode 100644 index 000000000..3dd8c1222 --- /dev/null +++ b/llama_stack/ui/app/logs/chat-completions/layout.tsx @@ -0,0 +1,45 @@ +"use client"; + +import React from "react"; +import { usePathname, useParams } from "next/navigation"; +import { + PageBreadcrumb, + BreadcrumbSegment, +} from "@/components/layout/page-breadcrumb"; +import { truncateText } from "@/lib/truncate-text"; + +export default function ChatCompletionsLayout({ + children, +}: { + children: React.ReactNode; +}) { + const pathname = usePathname(); + const params = useParams(); + + let segments: BreadcrumbSegment[] = []; + + // Default for /logs/chat-completions + if (pathname === "/logs/chat-completions") { + segments = [{ label: "Chat Completions" }]; + } + + // For /logs/chat-completions/[id] + const idParam = params?.id; + if (idParam && typeof idParam === "string") { + segments = [ + { label: "Chat Completions", href: "/logs/chat-completions" }, + { label: `Details (${truncateText(idParam, 20)})` }, + ]; + } + + return ( +
+ <> + {segments.length > 0 && ( + + )} + {children} + +
+ ); +} diff --git a/llama_stack/ui/app/logs/chat-completions/page.tsx b/llama_stack/ui/app/logs/chat-completions/page.tsx index 84cceb8b7..3de77a042 100644 --- a/llama_stack/ui/app/logs/chat-completions/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/page.tsx @@ -1,7 +1,54 @@ -export default function ChatCompletions() { +"use client"; + +import { useEffect, useState } from "react"; +import LlamaStackClient from "llama-stack-client"; +import { ChatCompletion } from "@/lib/types"; +import { ChatCompletionsTable } from "@/components/chat-completions/chat-completion-table"; + +export default function ChatCompletionsPage() { + const [completions, setCompletions] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + const client = new LlamaStackClient({ + baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL, + }); + const fetchCompletions = async () => { + setIsLoading(true); + setError(null); + try { + const response = await client.chat.completions.list(); + const data = Array.isArray(response) + ? response + : (response as any).data; + + if (Array.isArray(data)) { + setCompletions(data); + } else { + console.error("Unexpected response structure:", response); + setError(new Error("Unexpected response structure")); + setCompletions([]); + } + } catch (err) { + console.error("Error fetching chat completions:", err); + setError( + err instanceof Error ? err : new Error("Failed to fetch completions"), + ); + setCompletions([]); + } finally { + setIsLoading(false); + } + }; + + fetchCompletions(); + }, []); + return ( -
-

Under Construction

-
+ ); } diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx new file mode 100644 index 000000000..33247ed26 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-detail.test.tsx @@ -0,0 +1,193 @@ +import React from "react"; +import { render, screen } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { ChatCompletionDetailView } from "./chat-completion-detail"; +import { ChatCompletion } from "@/lib/types"; + +// Initial test file setup for ChatCompletionDetailView + +describe("ChatCompletionDetailView", () => { + test("renders skeleton UI when isLoading is true", () => { + const { container } = render( + , + ); + // Use the data-slot attribute for Skeletons + const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); + expect(skeletons.length).toBeGreaterThan(0); + }); + + test("renders error message when error prop is provided", () => { + render( + , + ); + expect( + screen.getByText(/Error loading details for ID err-id: Network Error/), + ).toBeInTheDocument(); + }); + + test("renders default error message when error.message is empty", () => { + render( + , + ); + // Use regex to match the error message regardless of whitespace + expect( + screen.getByText(/Error loading details for ID\s*err-id\s*:/), + ).toBeInTheDocument(); + }); + + test("renders error message when error prop is an object without message", () => { + render( + , + ); + // Use regex to match the error message regardless of whitespace + expect( + screen.getByText(/Error loading details for ID\s*err-id\s*:/), + ).toBeInTheDocument(); + }); + + test("renders not found message when completion is null and not loading/error", () => { + render( + , + ); + expect( + screen.getByText("No details found for completion ID: notfound-id."), + ).toBeInTheDocument(); + }); + + test("renders input, output, and properties for valid completion", () => { + const mockCompletion: ChatCompletion = { + id: "comp_123", + object: "chat.completion", + created: 1710000000, + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }; + render( + , + ); + // Input + expect(screen.getByText("Input")).toBeInTheDocument(); + expect(screen.getByText("Test input")).toBeInTheDocument(); + // Output + expect(screen.getByText("Output")).toBeInTheDocument(); + expect(screen.getByText("Test output")).toBeInTheDocument(); + // Properties + expect(screen.getByText("Properties")).toBeInTheDocument(); + expect(screen.getByText("Created:")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710000000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + expect(screen.getByText("ID:")).toBeInTheDocument(); + expect(screen.getByText("comp_123")).toBeInTheDocument(); + expect(screen.getByText("Model:")).toBeInTheDocument(); + expect(screen.getByText("llama-test-model")).toBeInTheDocument(); + expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); + expect(screen.getByText("stop")).toBeInTheDocument(); + }); + + test("renders tool call in output and properties when present", () => { + const toolCall = { + function: { name: "search", arguments: '{"query":"llama"}' }, + }; + const mockCompletion: ChatCompletion = { + id: "comp_tool", + object: "chat.completion", + created: 1710001000, + model: "llama-tool-model", + choices: [ 
+ { + index: 0, + message: { + role: "assistant", + content: "Tool output", + tool_calls: [toolCall], + }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Tool input" }], + }; + render( + , + ); + // Output should include the tool call block (should be present twice: input and output) + const toolCallLabels = screen.getAllByText("Tool Call"); + expect(toolCallLabels.length).toBeGreaterThanOrEqual(1); // At least one, but could be two + // The tool call block should contain the formatted tool call string in both input and output + const toolCallBlocks = screen.getAllByText('search({"query":"llama"})'); + expect(toolCallBlocks.length).toBe(2); + // Properties should include the tool call name + expect(screen.getByText("Functions/Tools Called:")).toBeInTheDocument(); + expect(screen.getByText("search")).toBeInTheDocument(); + }); + + test("handles missing/empty fields gracefully", () => { + const mockCompletion: ChatCompletion = { + id: "comp_edge", + object: "chat.completion", + created: 1710002000, + model: "llama-edge-model", + choices: [], // No choices + input_messages: [], // No input messages + }; + render( + , + ); + // Input section should be present but empty + expect(screen.getByText("Input")).toBeInTheDocument(); + // Output section should show fallback message + expect( + screen.getByText("No message found in assistant's choice."), + ).toBeInTheDocument(); + // Properties should show N/A for finish reason + expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); + expect(screen.getByText("N/A")).toBeInTheDocument(); + }); +}); diff --git a/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx b/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx new file mode 100644 index 000000000..e76418d1a --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-detail.tsx @@ -0,0 +1,198 @@ +"use client"; + +import { ChatMessage, ChatCompletion } from "@/lib/types"; +import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Skeleton } from "@/components/ui/skeleton"; + +function ChatCompletionDetailLoadingView() { + return ( + <> + {/* Title Skeleton */} +
+
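// A minimal sketch of the skeleton-loading pattern this view relies on: while data
// loads, map over a fixed-size array and render Skeleton placeholders inside Card
// shells. The component name, class names, and placeholder counts here are
// illustrative assumptions, not the exact markup of this patch.
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";

function ExampleDetailSkeleton() {
  return (
    <>
      {/* Page title placeholder */}
      <Skeleton className="h-8 w-1/3 mb-6" />
      {[...Array(2)].map((_, i) => (
        <Card key={`skeleton-card-${i}`}>
          <CardHeader>
            <CardTitle>
              <Skeleton className="h-6 w-1/2" />
            </CardTitle>
          </CardHeader>
          <CardContent>
            <Skeleton className="h-4 w-full mb-2" />
            <Skeleton className="h-4 w-3/4" />
          </CardContent>
        </Card>
      ))}
    </>
  );
}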
+ {[...Array(2)].map((_, i) => ( + + + + + + + + + + + + + ))} +
+
+
+ {" "} + {/* Properties Title Skeleton */} + {[...Array(5)].map((_, i) => ( +
+ + +
+ ))} +
+
+
+ + ); +} + +interface ChatCompletionDetailViewProps { + completion: ChatCompletion | null; + isLoading: boolean; + error: Error | null; + id: string; +} + +export function ChatCompletionDetailView({ + completion, + isLoading, + error, + id, +}: ChatCompletionDetailViewProps) { + if (error) { + return ( + <> + {/* We still want a title for consistency on error pages */} +
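// For orientation, a sketch of the ChatCompletion/ChatMessage shape that the props
// above and the test fixtures in this PR rely on. The real definitions live in
// "@/lib/types" and may carry additional fields; the "Sketch" names are illustrative.
interface SketchChatMessage {
  role: string;
  // content may also be an array of content parts; extractTextFromContentPart
  // flattens it to text for display (assumption based on usage in this PR).
  content: string;
  tool_calls?: Array<{ function: { name: string; arguments: string } }>;
}

interface SketchChatCompletion {
  id: string;
  object: "chat.completion";
  created: number; // Unix seconds; the UI multiplies by 1000 before new Date()
  model: string;
  choices: Array<{
    index: number;
    message: SketchChatMessage;
    finish_reason: string;
  }>;
  input_messages: SketchChatMessage[];
}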

Chat Completion Details

+

+ Error loading details for ID {id}: {error.message} +

+ + ); + } + + if (isLoading) { + return ; + } + + if (!completion) { + // This state means: not loading, no error, but no completion data + return ( + <> + {/* We still want a title for consistency on not-found pages */} +

Chat Completion Details

+

No details found for completion ID: {id}.

+ + ); + } + + // If no error, not loading, and completion exists, render the details: + return ( + <> +

Chat Completion Details

+
+
+ + + Input + + + {completion.input_messages?.map((msg, index) => ( + + ))} + {completion.choices?.[0]?.message?.tool_calls && + !completion.input_messages?.some( + (im) => + im.role === "assistant" && + im.tool_calls && + im.tool_calls.length > 0, + ) && + completion.choices[0].message.tool_calls.map( + (toolCall: any, index: number) => { + const assistantToolCallMessage: ChatMessage = { + role: "assistant", + tool_calls: [toolCall], + content: "", // Ensure content is defined, even if empty + }; + return ( + + ); + }, + )} + + + + + + Output + + + {completion.choices?.[0]?.message ? ( + + ) : ( +

+ No message found in assistant's choice. +

+ )} +
+
+
+ +
+ + + Properties + + +
    +
  • + Created:{" "} + + {new Date(completion.created * 1000).toLocaleString()} + +
  • +
  • + ID:{" "} + + {completion.id} + +
  • +
  • + Model:{" "} + + {completion.model} + +
  • +
  • + Finish Reason:{" "} + + {completion.choices?.[0]?.finish_reason || "N/A"} + +
  • + {completion.choices?.[0]?.message?.tool_calls && + completion.choices[0].message.tool_calls.length > 0 && ( +
  • + Functions/Tools Called: +
      + {completion.choices[0].message.tool_calls.map( + (toolCall: any, index: number) => ( +
    • + + {toolCall.function?.name || "N/A"} + +
    • + ), + )} +
    +
  • + )} +
+
+
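// The Properties panel renders a labelled key/value list; the tests above assert on
// the "Created:", "ID:", "Model:", "Finish Reason:" and "Functions/Tools Called:"
// labels. A minimal sketch of that structure follows; the list markup and class
// names are assumptions, not the patch's exact JSX.
import { ChatCompletion } from "@/lib/types";

function ExamplePropertiesList({ completion }: { completion: ChatCompletion }) {
  const toolCalls = completion.choices?.[0]?.message?.tool_calls ?? [];
  return (
    <ul className="space-y-2 text-sm">
      <li>
        <strong>Created:</strong>{" "}
        {new Date(completion.created * 1000).toLocaleString()}
      </li>
      <li>
        <strong>ID:</strong> {completion.id}
      </li>
      <li>
        <strong>Model:</strong> {completion.model}
      </li>
      <li>
        <strong>Finish Reason:</strong>{" "}
        {completion.choices?.[0]?.finish_reason || "N/A"}
      </li>
      {toolCalls.length > 0 && (
        <li>
          <strong>Functions/Tools Called:</strong>
          <ul className="list-disc list-inside pl-4">
            {toolCalls.map((toolCall: any, index: number) => (
              <li key={index}>{toolCall.function?.name || "N/A"}</li>
            ))}
          </ul>
        </li>
      )}
    </ul>
  );
}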
+
+
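// Usage sketch: a client page component that owns the loading/error state and passes
// it to ChatCompletionDetailView. fetchChatCompletion is hypothetical — substitute
// whatever call your app uses to load a completion by ID.
"use client";

import { useEffect, useState } from "react";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";

declare function fetchChatCompletion(id: string): Promise<ChatCompletion>; // hypothetical helper

export function ExampleChatCompletionDetailPage({ id }: { id: string }) {
  const [completion, setCompletion] = useState<ChatCompletion | null>(null);
  const [isLoading, setIsLoading] = useState(true);
  const [error, setError] = useState<Error | null>(null);

  useEffect(() => {
    setIsLoading(true);
    fetchChatCompletion(id)
      .then(setCompletion)
      .catch((e) => setError(e instanceof Error ? e : new Error(String(e))))
      .finally(() => setIsLoading(false));
  }, [id]);

  return (
    <ChatCompletionDetailView
      completion={completion}
      isLoading={isLoading}
      error={error}
      id={id}
    />
  );
}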
+ + ); +} diff --git a/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx b/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx new file mode 100644 index 000000000..e71ef3d43 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-table.test.tsx @@ -0,0 +1,340 @@ +import React from "react"; +import { render, screen, fireEvent } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { ChatCompletionsTable } from "./chat-completion-table"; +import { ChatCompletion } from "@/lib/types"; // Assuming this path is correct + +// Mock next/navigation +const mockPush = jest.fn(); +jest.mock("next/navigation", () => ({ + useRouter: () => ({ + push: mockPush, + }), +})); + +// Mock helper functions +// These are hoisted, so their mocks are available throughout the file +jest.mock("@/lib/truncate-text"); +jest.mock("@/lib/format-tool-call"); + +// Import the mocked functions to set up default or specific implementations +import { truncateText as originalTruncateText } from "@/lib/truncate-text"; +import { formatToolCallToString as originalFormatToolCallToString } from "@/lib/format-tool-call"; + +// Cast to jest.Mock for typings +const truncateText = originalTruncateText as jest.Mock; +const formatToolCallToString = originalFormatToolCallToString as jest.Mock; + +describe("ChatCompletionsTable", () => { + const defaultProps = { + completions: [] as ChatCompletion[], + isLoading: false, + error: null, + }; + + beforeEach(() => { + // Reset all mocks before each test + mockPush.mockClear(); + truncateText.mockClear(); + formatToolCallToString.mockClear(); + + // Default pass-through implementation for tests not focusing on truncation/formatting + truncateText.mockImplementation((text: string | undefined) => text); + formatToolCallToString.mockImplementation((toolCall: any) => + toolCall && typeof toolCall === "object" && toolCall.name + ? 
`[DefaultToolCall:${toolCall.name}]` + : "[InvalidToolCall]", + ); + }); + + test("renders without crashing with default props", () => { + render(); + // Check for a unique element that should be present in the non-empty, non-loading, non-error state + // For now, as per Task 1, we will test the empty state message + expect(screen.getByText("No chat completions found.")).toBeInTheDocument(); + }); + + test("click on a row navigates to the correct URL", () => { + const { rerender } = render(); + + // Simulate a scenario where a completion exists and is clicked + const mockCompletion: ChatCompletion = { + id: "comp_123", + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }; + + rerender( + , + ); + const row = screen.getByText("Test input").closest("tr"); + if (row) { + fireEvent.click(row); + expect(mockPush).toHaveBeenCalledWith("/logs/chat-completions/comp_123"); + } else { + throw new Error('Row with "Test input" not found for router mock test.'); + } + }); + + describe("Loading State", () => { + test("renders skeleton UI when isLoading is true", () => { + const { container } = render( + , + ); + + // The Skeleton component uses data-slot="skeleton" + const skeletonSelector = '[data-slot="skeleton"]'; + + // Check for skeleton in the table caption + const tableCaption = container.querySelector("caption"); + expect(tableCaption).toBeInTheDocument(); + if (tableCaption) { + const captionSkeleton = tableCaption.querySelector(skeletonSelector); + expect(captionSkeleton).toBeInTheDocument(); + } + + // Check for skeletons in the table body cells + const tableBody = container.querySelector("tbody"); + expect(tableBody).toBeInTheDocument(); + if (tableBody) { + const bodySkeletons = tableBody.querySelectorAll( + `td ${skeletonSelector}`, + ); + expect(bodySkeletons.length).toBeGreaterThan(0); // Ensure at least one skeleton cell exists + } + + // General check: ensure multiple skeleton elements are present in the table overall + const allSkeletonsInTable = container.querySelectorAll( + `table ${skeletonSelector}`, + ); + expect(allSkeletonsInTable.length).toBeGreaterThan(3); // e.g., caption + at least one row of 3 cells, or just a few + }); + }); + + describe("Error State", () => { + test("renders error message when error prop is provided", () => { + const errorMessage = "Network Error"; + render( + , + ); + expect( + screen.getByText(`Error fetching data: ${errorMessage}`), + ).toBeInTheDocument(); + }); + + test("renders default error message when error.message is not available", () => { + render( + , + ); // Error with empty message + expect( + screen.getByText("Error fetching data: An unknown error occurred"), + ).toBeInTheDocument(); + }); + + test("renders default error message when error prop is an object without message", () => { + render(); // Empty error object + expect( + screen.getByText("Error fetching data: An unknown error occurred"), + ).toBeInTheDocument(); + }); + }); + + describe("Empty State", () => { + test('renders "No chat completions found." 
and no table when completions array is empty', () => { + render( + , + ); + expect( + screen.getByText("No chat completions found."), + ).toBeInTheDocument(); + + // Ensure that the table structure is NOT rendered in the empty state + const table = screen.queryByRole("table"); + expect(table).not.toBeInTheDocument(); + }); + }); + + describe("Data Rendering", () => { + test("renders table caption, headers, and completion data correctly", () => { + const mockCompletions = [ + { + id: "comp_1", + object: "chat.completion", + created: 1710000000, // Fixed timestamp for test + model: "llama-test-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Test output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Test input" }], + }, + { + id: "comp_2", + object: "chat.completion", + created: 1710001000, + model: "llama-another-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Another output" }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Another input" }], + }, + ]; + + render( + , + ); + + // Table caption + expect( + screen.getByText("A list of your recent chat completions."), + ).toBeInTheDocument(); + + // Table headers + expect(screen.getByText("Input")).toBeInTheDocument(); + expect(screen.getByText("Output")).toBeInTheDocument(); + expect(screen.getByText("Model")).toBeInTheDocument(); + expect(screen.getByText("Created")).toBeInTheDocument(); + + // Data rows + expect(screen.getByText("Test input")).toBeInTheDocument(); + expect(screen.getByText("Test output")).toBeInTheDocument(); + expect(screen.getByText("llama-test-model")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710000000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + + expect(screen.getByText("Another input")).toBeInTheDocument(); + expect(screen.getByText("Another output")).toBeInTheDocument(); + expect(screen.getByText("llama-another-model")).toBeInTheDocument(); + expect( + screen.getByText(new Date(1710001000 * 1000).toLocaleString()), + ).toBeInTheDocument(); + }); + }); + + describe("Text Truncation and Tool Call Formatting", () => { + test("truncates long input and output text", () => { + // Specific mock implementation for this test + truncateText.mockImplementation( + (text: string | undefined, maxLength?: number) => { + const defaultTestMaxLength = 10; + const effectiveMaxLength = maxLength ?? defaultTestMaxLength; + return typeof text === "string" && text.length > effectiveMaxLength + ? text.slice(0, effectiveMaxLength) + "..." 
+ : text; + }, + ); + + const longInput = + "This is a very long input message that should be truncated."; + const longOutput = + "This is a very long output message that should also be truncated."; + const mockCompletions = [ + { + id: "comp_trunc", + object: "chat.completion", + created: 1710002000, + model: "llama-trunc-model", + choices: [ + { + index: 0, + message: { role: "assistant", content: longOutput }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: longInput }], + }, + ]; + + render( + , + ); + + // The truncated text should be present for both input and output + const truncatedTexts = screen.getAllByText( + longInput.slice(0, 10) + "...", + ); + expect(truncatedTexts.length).toBe(2); // one for input, one for output + // Optionally, verify each one is in the document if getAllByText doesn't throw on not found + truncatedTexts.forEach((textElement) => + expect(textElement).toBeInTheDocument(), + ); + }); + + test("formats tool call output using formatToolCallToString", () => { + // Specific mock implementation for this test + formatToolCallToString.mockImplementation( + (toolCall: any) => `[TOOL:${toolCall.name}]`, + ); + // Ensure no truncation interferes for this specific test for clarity of tool call format + truncateText.mockImplementation((text: string | undefined) => text); + + const toolCall = { name: "search", args: { query: "llama" } }; + const mockCompletions = [ + { + id: "comp_tool", + object: "chat.completion", + created: 1710003000, + model: "llama-tool-model", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Tool output", // Content that will be prepended + tool_calls: [toolCall], + }, + finish_reason: "stop", + }, + ], + input_messages: [{ role: "user", content: "Tool input" }], + }, + ]; + + render( + , + ); + + // The component concatenates message.content and the formatted tool call + expect(screen.getByText("Tool output [TOOL:search]")).toBeInTheDocument(); + }); + }); +}); diff --git a/llama_stack/ui/components/chat-completions/chat-completion-table.tsx b/llama_stack/ui/components/chat-completions/chat-completion-table.tsx new file mode 100644 index 000000000..e11acf376 --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-completion-table.tsx @@ -0,0 +1,120 @@ +"use client"; + +import { useRouter } from "next/navigation"; +import { ChatCompletion } from "@/lib/types"; +import { truncateText } from "@/lib/truncate-text"; +import { + extractTextFromContentPart, + extractDisplayableText, +} from "@/lib/format-message-content"; +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { Skeleton } from "@/components/ui/skeleton"; + +interface ChatCompletionsTableProps { + completions: ChatCompletion[]; + isLoading: boolean; + error: Error | null; +} + +export function ChatCompletionsTable({ + completions, + isLoading, + error, +}: ChatCompletionsTableProps) { + const router = useRouter(); + + const tableHeader = ( + + + Input + Output + Model + Created + + + ); + + if (isLoading) { + return ( + + + + + {tableHeader} + + {[...Array(3)].map((_, i) => ( + + + + + + + + + + + + + + + ))} + +
+ ); + } + + if (error) { + return ( +
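// The table leans on a few small lib helpers. Their real implementations live under
// "@/lib/"; the sketches below only capture the behaviour exercised by the tests in
// this PR (truncation with an ellipsis, "name(arguments)" tool-call formatting, and
// output text assembled from message content plus formatted tool calls). The
// content-part shape and default truncation length are assumptions.
type ContentPart = string | Array<{ type: string; text?: string }>;

function truncateTextSketch(text: string | undefined, maxLength = 140): string {
  if (!text) return "";
  return text.length > maxLength ? text.slice(0, maxLength) + "..." : text;
}

function formatToolCallToStringSketch(toolCall: any): string {
  const name = toolCall?.function?.name ?? "";
  const args = toolCall?.function?.arguments ?? "";
  return name ? `${name}(${args})` : "";
}

function extractTextFromContentPartSketch(content: ContentPart | undefined): string {
  if (typeof content === "string") return content;
  return (content ?? [])
    .map((part) => part.text ?? "")
    .join(" ")
    .trim();
}

function extractDisplayableTextSketch(message: any): string {
  if (!message) return "";
  const text = extractTextFromContentPartSketch(message.content);
  const toolParts = (message.tool_calls ?? []).map(formatToolCallToStringSketch);
  return [text, ...toolParts].filter(Boolean).join(" ");
}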

Error fetching data: {error.message || "An unknown error occurred"}

+ ); + } + + if (completions.length === 0) { + return

No chat completions found.

; + } + + return ( + + A list of your recent chat completions. + {tableHeader} + + {completions.map((completion) => ( + + router.push(`/logs/chat-completions/${completion.id}`) + } + className="cursor-pointer hover:bg-muted/50" + > + + {truncateText( + extractTextFromContentPart( + completion.input_messages?.[0]?.content, + ), + )} + + + {(() => { + const message = completion.choices?.[0]?.message; + const outputText = extractDisplayableText(message); + return truncateText(outputText); + })()} + + {completion.model} + + {new Date(completion.created * 1000).toLocaleString()} + + + ))} + +
+ ); +} diff --git a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx new file mode 100644 index 000000000..58a009aed --- /dev/null +++ b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx @@ -0,0 +1,107 @@ +"use client"; + +import { ChatMessage } from "@/lib/types"; +import React from "react"; +import { formatToolCallToString } from "@/lib/format-tool-call"; +import { extractTextFromContentPart } from "@/lib/format-message-content"; + +// Sub-component or helper for the common label + content structure +const MessageBlock: React.FC<{ + label: string; + labelDetail?: string; + content: React.ReactNode; +}> = ({ label, labelDetail, content }) => { + return ( +
+

+ {label} + {labelDetail && ( + + {labelDetail} + + )} +

+
{content}
+
+ ); +}; + +interface ToolCallBlockProps { + children: React.ReactNode; + className?: string; +} + +const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => { + // Common styling for both function call arguments and tool output blocks + // Let's use slate-50 background as it's good for code-like content. + const baseClassName = + "p-3 bg-slate-50 border border-slate-200 rounded-md text-sm"; + + return ( +
+
{children}
+
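// Usage sketch: ChatMessageItem is role-driven, so rendering a transcript is a plain
// map over messages; assistant tool calls and tool outputs come back wrapped in the
// ToolCallBlock styling defined above. The transcript below is illustrative and its
// field shapes follow the test fixtures in this PR.
import { ChatMessage } from "@/lib/types";
import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item";

const exampleTranscript: ChatMessage[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Search for llamas." },
  {
    role: "assistant",
    content: "",
    tool_calls: [{ function: { name: "search", arguments: '{"query":"llama"}' } }],
  },
  { role: "tool", content: '{"results": ["llama facts"]}' },
];

export function ExampleTranscript() {
  return (
    <>
      {exampleTranscript.map((msg, index) => (
        <ChatMessageItem key={index} message={msg} />
      ))}
    </>
  );
}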
+ ); +}; + +interface ChatMessageItemProps { + message: ChatMessage; +} +export function ChatMessageItem({ message }: ChatMessageItemProps) { + switch (message.role) { + case "system": + return ( + + ); + case "user": + return ( + + ); + + case "assistant": + if (message.tool_calls && message.tool_calls.length > 0) { + return ( + <> + {message.tool_calls.map((toolCall: any, index: number) => { + const formattedToolCall = formatToolCallToString(toolCall); + const toolCallContent = ( + + {formattedToolCall || "Error: Could not display tool call"} + + ); + return ( + + ); + })} + + ); + } else { + return ( + + ); + } + case "tool": + const toolOutputContent = ( + + {extractTextFromContentPart(message.content)} + + ); + return ( + + ); + } + return null; +} diff --git a/llama_stack/ui/components/app-sidebar.tsx b/llama_stack/ui/components/layout/app-sidebar.tsx similarity index 50% rename from llama_stack/ui/components/app-sidebar.tsx rename to llama_stack/ui/components/layout/app-sidebar.tsx index 3d541856f..1c53d6cc5 100644 --- a/llama_stack/ui/components/app-sidebar.tsx +++ b/llama_stack/ui/components/layout/app-sidebar.tsx @@ -1,5 +1,9 @@ +"use client"; + import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react"; import Link from "next/link"; +import { usePathname } from "next/navigation"; +import { cn } from "@/lib/utils"; import { Sidebar, @@ -32,6 +36,8 @@ const logItems = [ ]; export function AppSidebar() { + const pathname = usePathname(); + return ( @@ -42,16 +48,31 @@ export function AppSidebar() { Logs - {logItems.map((item) => ( - - - - - {item.title} - - - - ))} + {logItems.map((item) => { + const isActive = pathname.startsWith(item.url); + return ( + + + + + {item.title} + + + + ); + })} diff --git a/llama_stack/ui/components/layout/page-breadcrumb.tsx b/llama_stack/ui/components/layout/page-breadcrumb.tsx new file mode 100644 index 000000000..fdb561d68 --- /dev/null +++ b/llama_stack/ui/components/layout/page-breadcrumb.tsx @@ -0,0 +1,49 @@ +"use client"; + +import Link from "next/link"; +import React from "react"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; + +export interface BreadcrumbSegment { + label: string; + href?: string; +} + +interface PageBreadcrumbProps { + segments: BreadcrumbSegment[]; + className?: string; +} + +export function PageBreadcrumb({ segments, className }: PageBreadcrumbProps) { + if (!segments || segments.length === 0) { + return null; + } + + return ( + + + {segments.map((segment, index) => ( + + + {segment.href ? ( + + {segment.label} + + ) : ( + {segment.label} + )} + + {index < segments.length - 1 && } + + ))} + + + ); +} diff --git a/llama_stack/ui/components/ui/breadcrumb.tsx b/llama_stack/ui/components/ui/breadcrumb.tsx new file mode 100644 index 000000000..f63ae19af --- /dev/null +++ b/llama_stack/ui/components/ui/breadcrumb.tsx @@ -0,0 +1,109 @@ +import * as React from "react"; +import { Slot } from "@radix-ui/react-slot"; +import { ChevronRight, MoreHorizontal } from "lucide-react"; + +import { cn } from "@/lib/utils"; + +function Breadcrumb({ ...props }: React.ComponentProps<"nav">) { + return
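// Usage sketch for PageBreadcrumb: ancestor segments carry an href and render as
// links, while the final segment (no href) renders as the current page. The labels
// and route below are illustrative.
import { PageBreadcrumb } from "@/components/layout/page-breadcrumb";

export function ExampleDetailHeader({ id }: { id: string }) {
  return (
    <PageBreadcrumb
      segments={[
        { label: "Chat Completions", href: "/logs/chat-completions" },
        { label: `Details (${id})` },
      ]}
    />
  );
}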