diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 6e84d94e0..6787806e9 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -18,7 +18,7 @@ on: - '.github/workflows/integration-auth-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 485e546fa..3efd970e1 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -16,7 +16,7 @@ on: - '.github/workflows/integration-sql-store-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 194c362c4..5f13620f7 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,7 +8,7 @@ on: branches: [main] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -36,20 +36,16 @@ jobs: **/requirements*.txt .pre-commit-config.yaml - # npm ci may fail - - # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing. 
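A note on the concurrency change repeated across these workflows: the expression `${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}` behaves like a ternary, so pushes to `main` get a unique group per run (runs on `main` are never cancelled) while PR branches keep sharing one group per ref and cancel superseded runs. A minimal Python sketch of the equivalent selection logic, purely illustrative (`build_group` is a hypothetical helper, not part of the repo):

```python
# Illustrative only: mirrors the GitHub Actions expression
#   ${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
# assuming run_id is always truthy (it is a non-zero integer).
def build_group(workflow: str, ref: str, run_id: int) -> str:
    suffix = run_id if ref == "refs/heads/main" else ref
    return f"{workflow}-{suffix}"


print(build_group("unit-tests", "refs/heads/main", 17168469123))     # unique per run on main
print(build_group("unit-tests", "refs/pull/123/merge", 17168469123)) # shared per PR branch
```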
- # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18 + - name: Set up Node.js + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + with: + node-version: '20' + cache: 'npm' + cache-dependency-path: 'llama_stack/ui/' - # - name: Set up Node.js - # uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 - # with: - # node-version: '20' - # cache: 'npm' - # cache-dependency-path: 'llama_stack/ui/' - - # - name: Install npm dependencies - # run: npm ci - # working-directory: llama_stack/ui + - name: Install npm dependencies + run: npm ci + working-directory: llama_stack/ui - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 continue-on-error: true diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 461c25148..391acbcf8 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -26,7 +26,7 @@ on: - 'pyproject.toml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -106,6 +106,10 @@ jobs: - name: Inspect the container image entrypoint run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then @@ -140,6 +144,10 @@ jobs: - name: Inspect UBI9 image run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index 9de53f7fb..bf9a3e057 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install uv - uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 + uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 4adaca84d..4a078fa00 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -22,6 +22,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check PR Title's semantic conformance - uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ui-unit-tests.yml b/.github/workflows/ui-unit-tests.yml index 4b0d62e90..2afb92bee 100644 --- a/.github/workflows/ui-unit-tests.yml +++ b/.github/workflows/ui-unit-tests.yml @@ -13,7 +13,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} 
cancel-in-progress: true jobs: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index cce8d9ff6..dd2097a45 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -18,7 +18,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 9ed89a271..e12f0adf8 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -27,7 +27,7 @@ on: - '.github/workflows/update-readthedocs.yml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d25455cf0..514fe6d2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -146,31 +146,13 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ - # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail - - # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing. - # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18 - # and until we have infra for installing prettier and next via npm - - # Lint UI code with ESLint.....................................................Failed - # - hook id: ui-eslint - # - exit code: 127 - # > ui@0.1.0 lint - # > next lint --fix --quiet - # sh: line 1: next: command not found - # - # - id: ui-prettier - # name: Format UI code with Prettier - # entry: bash -c 'cd llama_stack/ui && npm ci && npm run format' - # language: system - # files: ^llama_stack/ui/.*\.(ts|tsx)$ - # pass_filenames: false - # require_serial: true - # - id: ui-eslint - # name: Lint UI code with ESLint - # entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' - # language: system - # files: ^llama_stack/ui/.*\.(ts|tsx)$ - # pass_filenames: false - # require_serial: true + - id: ui-linter + name: Format & Lint UI + entry: bash ./scripts/run-ui-linter.sh + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true - id: check-log-usage name: Ensure 'llama_stack.log' usage for logging diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b36626719..a1f6a6f30 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4605,6 +4605,49 @@ } } }, + "/v1/inference/rerank": { + "post": { + "responses": { + "200": { + "description": "RerankResponse with indices sorted by relevance score (descending).", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Rerank a list of documents based on their relevance to a query.", + 
"parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankRequest" + } + } + }, + "required": true + } + } + }, "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "post": { "responses": { @@ -16024,12 +16067,16 @@ "value": { "type": "number", "description": "The numeric value of the metric at this timestamp" + }, + "unit": { + "type": "string" } }, "additionalProperties": false, "required": [ "timestamp", - "value" + "value", + "unit" ], "title": "MetricDataPoint", "description": "A single data point in a metric time series." @@ -16587,6 +16634,95 @@ ], "title": "RegisterVectorDbRequest" }, + "RerankRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the reranking model to use." + }, + "query": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ], + "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length." + }, + "items": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ] + }, + "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length." + }, + "max_num_results": { + "type": "integer", + "description": "(Optional) Maximum number of results to return. Default: returns all." + } + }, + "additionalProperties": false, + "required": [ + "model", + "query", + "items" + ], + "title": "RerankRequest" + }, + "RerankData": { + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "The original index of the document in the input list" + }, + "relevance_score": { + "type": "number", + "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance." + } + }, + "additionalProperties": false, + "required": [ + "index", + "relevance_score" + ], + "title": "RerankData", + "description": "A single rerank result from a reranking response." + }, + "RerankResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RerankData" + }, + "description": "List of rerank result objects, sorted by relevance score (descending)" + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "RerankResponse", + "description": "Response from a reranking request." + }, "ResumeAgentTurnRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e7733b3c3..33142e3ff 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3264,6 +3264,37 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/inference/rerank: + post: + responses: + '200': + description: >- + RerankResponse with indices sorted by relevance score (descending). 
+ content: + application/json: + schema: + $ref: '#/components/schemas/RerankResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Rerank a list of documents based on their relevance to a query. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RerankRequest' + required: true /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: post: responses: @@ -11923,10 +11954,13 @@ components: type: number description: >- The numeric value of the metric at this timestamp + unit: + type: string additionalProperties: false required: - timestamp - value + - unit title: MetricDataPoint description: >- A single data point in a metric time series. @@ -12337,6 +12371,76 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + RerankRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the reranking model to use. + query: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + The search query to rank items against. Can be a string, text content + part, or image content part. The input must not exceed the model's max + input token length. + items: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + List of items to rerank. Each item can be a string, text content part, + or image content part. Each input must not exceed the model's max input + token length. + max_num_results: + type: integer + description: >- + (Optional) Maximum number of results to return. Default: returns all. + additionalProperties: false + required: + - model + - query + - items + title: RerankRequest + RerankData: + type: object + properties: + index: + type: integer + description: >- + The original index of the document in the input list + relevance_score: + type: number + description: >- + The relevance score from the model output. Values are inverted when applicable + so that higher scores indicate greater relevance. + additionalProperties: false + required: + - index + - relevance_score + title: RerankData + description: >- + A single rerank result from a reranking response. + RerankResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/RerankData' + description: >- + List of rerank result objects, sorted by relevance score (descending) + additionalProperties: false + required: + - data + title: RerankResponse + description: Response from a reranking request. ResumeAgentTurnRequest: type: object properties: diff --git a/docs/source/advanced_apis/evaluation_concepts.md b/docs/source/advanced_apis/evaluation_concepts.md index c26ec8f5e..52ad53ece 100644 --- a/docs/source/advanced_apis/evaluation_concepts.md +++ b/docs/source/advanced_apis/evaluation_concepts.md @@ -33,7 +33,7 @@ The list of open-benchmarks we currently support: - [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models. 
-You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack +You can follow this [contributing guide](../references/evals_reference/index.md#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack #### Run evaluation on open-benchmarks via CLI diff --git a/docs/source/advanced_apis/post_training/inline_huggingface.md b/docs/source/advanced_apis/post_training/inline_huggingface.md index 4d2201c99..6536b4f8c 100644 --- a/docs/source/advanced_apis/post_training/inline_huggingface.md +++ b/docs/source/advanced_apis/post_training/inline_huggingface.md @@ -35,3 +35,6 @@ device: cpu ``` +[Find more detailed information here!](huggingface.md) + + diff --git a/docs/source/advanced_apis/post_training/inline_torchtune.md b/docs/source/advanced_apis/post_training/inline_torchtune.md index 6684c99ac..617975b0d 100644 --- a/docs/source/advanced_apis/post_training/inline_torchtune.md +++ b/docs/source/advanced_apis/post_training/inline_torchtune.md @@ -22,3 +22,4 @@ checkpoint_format: meta ``` +[Find more detailed information here!](torchtune.md) diff --git a/docs/source/building_applications/playground/index.md b/docs/source/building_applications/playground/index.md index fd2b92434..2390c422f 100644 --- a/docs/source/building_applications/playground/index.md +++ b/docs/source/building_applications/playground/index.md @@ -88,7 +88,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie - **API Resources**: Inspect Llama Stack API resources - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`). - Under the hood, it uses Llama Stack's `//list` API to get information about each resources. - - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources. + - Please visit [Core Concepts](../../concepts/index.md) for more details about the resources. ### Starting the Llama Stack Playground diff --git a/docs/source/building_applications/responses_vs_agents.md b/docs/source/building_applications/responses_vs_agents.md index 5abe951d6..63ff69e4f 100644 --- a/docs/source/building_applications/responses_vs_agents.md +++ b/docs/source/building_applications/responses_vs_agents.md @@ -3,7 +3,7 @@ Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics. ```{note} -For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. + **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](../providers/openai.md#chat-completions) directly, before progressing to Agents or Responses API. 
``` ## Overview @@ -173,7 +173,7 @@ Both APIs demonstrate distinct strengths that make them valuable on their own fo ## For More Information -- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) +- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](agent.md) - **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses) -- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) -- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent_execution_loop.html) +- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](../providers/openai.md#chat-completions) +- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](agent_execution_loop.md) diff --git a/docs/source/concepts/distributions.md b/docs/source/concepts/distributions.md index c3be12d93..8c63914d1 100644 --- a/docs/source/concepts/distributions.md +++ b/docs/source/concepts/distributions.md @@ -6,4 +6,4 @@ While there is a lot of flexibility to mix-and-match providers, often users will **Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros. 
-**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html) +**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](../distributions/ondevice_distro/ios_sdk.md) and [Android](../distributions/ondevice_distro/android_sdk.md) diff --git a/docs/source/contributing/new_api_provider.md b/docs/source/contributing/new_api_provider.md index 6f8f59a47..9a7a62a38 100644 --- a/docs/source/contributing/new_api_provider.md +++ b/docs/source/contributing/new_api_provider.md @@ -14,6 +14,13 @@ Here are some example PRs to help you get started: - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355) - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665) +## Guidelines for creating Internal or External Providers + +|**Type** |Internal (In-tree) |External (out-of-tree) +|---------|-------------------|---------------------| +|**Description** |A provider that is directly in the Llama Stack code|A provider that is outside of the Llama stack core codebase but is still accessible and usable by Llama Stack. +|**Benefits** |Ability to interact with the provider with minimal additional configurations or installations| Contributors do not have to add directly to the code to create providers accessible on Llama Stack. Keep provider-specific code separate from the core Llama Stack code. + ## Inference Provider Patterns When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers. diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,8 +225,32 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. @@ -618,6 +642,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. 
Disabled by default. + +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). | `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index fbc48dd95..9993be227 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient( # provider_data is optional, but if you need to pass in any provider specific data, you can do so here. provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, ) -client.initialize() ``` This will parse your config and set up any inline implementations and remote clients needed for your implementation. @@ -28,9 +27,8 @@ Then, you can access the APIs like `models` and `inference` on the client and ca response = client.models.list() ``` -If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html), you can also use the run.yaml configuration file directly: +If you've created a [custom distribution](building_distro.md), you can also use the run.yaml configuration file directly: ```python client = LlamaStackAsLibraryClient(config_path) -client.initialize() ``` diff --git a/docs/source/distributions/k8s/apply.sh b/docs/source/distributions/k8s/apply.sh index 3356da53e..1b5b26863 100755 --- a/docs/source/distributions/k8s/apply.sh +++ b/docs/source/distributions/k8s/apply.sh @@ -22,17 +22,17 @@ else fi if [ -z "${GITHUB_CLIENT_ID:-}" ]; then - echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation." exit 1 fi if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then - echo "ERROR: GITHUB_CLIENT_SECRET not set. 
You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation." exit 1 fi if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then - echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation." exit 1 fi diff --git a/docs/source/distributions/ondevice_distro/android_sdk.md b/docs/source/distributions/ondevice_distro/android_sdk.md index 9d16d07d7..ad86fa5f3 100644 --- a/docs/source/distributions/ondevice_distro/android_sdk.md +++ b/docs/source/distributions/ondevice_distro/android_sdk.md @@ -66,7 +66,7 @@ llama stack run starter --port 5050 Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility. -Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations) +Other inference providers: [Table](../../index.md#supported-llama-stack-implementations) How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings) diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md index 7e50a4161..84b85b91c 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md @@ -2,7 +2,7 @@ orphan: true --- -# Meta Reference Distribution +# Meta Reference GPU Distribution ```{toctree} :maxdepth: 2 @@ -41,7 +41,7 @@ The following environment variables can be configured: ## Prerequisite: Downloading Models -Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. +Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. ``` $ llama model list --downloaded diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md index 2a39a626c..d6d2fa9a3 100644 --- a/docs/source/providers/batches/index.md +++ b/docs/source/providers/batches/index.md @@ -2,12 +2,15 @@ ## Overview -Protocol for batch processing API operations. 
- - The Batches API enables efficient processing of multiple requests in a single operation, +The Batches API enables efficient processing of multiple requests in a single operation, particularly useful for processing large datasets, batch evaluation workflows, and cost-effective inference at scale. + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/source/providers/files/index.md b/docs/source/providers/files/index.md index 692aad3ca..128953223 100644 --- a/docs/source/providers/files/index.md +++ b/docs/source/providers/files/index.md @@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files* :maxdepth: 1 inline_localfs +remote_s3 ``` diff --git a/docs/source/providers/files/remote_s3.md b/docs/source/providers/files/remote_s3.md new file mode 100644 index 000000000..2e3cebabd --- /dev/null +++ b/docs/source/providers/files/remote_s3.md @@ -0,0 +1,33 @@ +# remote::s3 + +## Description + +AWS S3-based file storage provider for scalable cloud file management with metadata persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `bucket_name` | `` | No | | S3 bucket name to store files | +| `region` | `` | No | us-east-1 | AWS region where the bucket is located | +| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) | +| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) | +| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) | +| `auto_create_bucket` | `` | No | False | Automatically create the S3 bucket if it doesn't exist | +| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +bucket_name: ${env.S3_BUCKET_NAME} +region: ${env.AWS_REGION:=us-east-1} +aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=} +aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=} +endpoint_url: ${env.S3_ENDPOINT_URL:=} +auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db + +``` + diff --git a/docs/source/providers/post_training/index.md b/docs/source/providers/post_training/index.md index c6c92c40e..e69f2a45a 100644 --- a/docs/source/providers/post_training/index.md +++ b/docs/source/providers/post_training/index.md @@ -9,7 +9,8 @@ This section contains documentation for all available providers for the **post_t ```{toctree} :maxdepth: 1 -inline_huggingface -inline_torchtune +inline_huggingface-gpu +inline_torchtune-cpu +inline_torchtune-gpu remote_nvidia ``` diff --git a/docs/source/providers/post_training/inline_huggingface-cpu.md b/docs/source/providers/post_training/inline_huggingface-cpu.md new file mode 100644 index 000000000..e663fe8f8 --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-cpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-cpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. 
+ +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_huggingface-gpu.md b/docs/source/providers/post_training/inline_huggingface-gpu.md new file mode 100644 index 000000000..21bf965fe --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-gpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-gpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-cpu.md b/docs/source/providers/post_training/inline_torchtune-cpu.md new file mode 100644 index 000000000..7204e56e8 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-cpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-cpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework. 
+ +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-gpu.md b/docs/source/providers/post_training/inline_torchtune-gpu.md new file mode 100644 index 000000000..98b94f6f6 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-gpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-gpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/docs/source/references/evals_reference/index.md b/docs/source/references/evals_reference/index.md index 054a0b809..9a5ed2f1b 100644 --- a/docs/source/references/evals_reference/index.md +++ b/docs/source/references/evals_reference/index.md @@ -202,7 +202,7 @@ pprint(response) Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets. -In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings. +In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](../../building_applications/playground/index.md) for an interactive interface to upload datasets and run scorings. ```python judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8" diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py index 9297d8597..c6bbd92eb 100644 --- a/llama_stack/apis/batches/batches.py +++ b/llama_stack/apis/batches/batches.py @@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel): @runtime_checkable class Batches(Protocol): - """Protocol for batch processing API operations. - + """ The Batches API enables efficient processing of multiple requests in a single operation, particularly useful for processing large datasets, batch evaluation workflows, and cost-effective inference at scale. + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + Note: This API is currently under active development and may undergo changes. """ @@ -45,6 +49,7 @@ class Batches(Protocol): endpoint: str, completion_window: Literal["24h"], metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, ) -> BatchObject: """Create a new batch for processing multiple API requests. @@ -52,6 +57,7 @@ class Batches(Protocol): :param endpoint: The endpoint to be used for all requests in the batch. :param completion_window: The time window within which the batch should be processed. 
:param metadata: Optional metadata for the batch. + :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. :returns: The created batch object. """ ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7e7bd0a3d..bd4737ca7 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel): embeddings: list[list[float]] +@json_schema_type +class RerankData(BaseModel): + """A single rerank result from a reranking response. + + :param index: The original index of the document in the input list + :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance. + """ + + index: int + relevance_score: float + + +@json_schema_type +class RerankResponse(BaseModel): + """Response from a reranking request. + + :param data: List of rerank result objects, sorted by relevance score (descending) + """ + + data: list[RerankData] + + @json_schema_type class OpenAIChatCompletionContentPartTextParam(BaseModel): """Text content part for OpenAI-compatible chat completion messages. @@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol): :returns: A BatchCompletionResponse with the full completions. """ raise NotImplementedError("Batch completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/chat-completion", method="POST") async def chat_completion( @@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol): :returns: A BatchChatCompletionResponse with the full completions. """ raise NotImplementedError("Batch chat completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/embeddings", method="POST") async def embeddings( @@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/inference/rerank", method="POST", experimental=True) + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + """Rerank a list of documents based on their relevance to a query. + + :param model: The identifier of the reranking model to use. + :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length. + :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length. + :param max_num_results: (Optional) Maximum number of results to return. Default: returns all. + :returns: RerankResponse with indices sorted by relevance score (descending). 
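To make the shape of the new rerank contract concrete, here is a minimal sketch using the `RerankData` and `RerankResponse` models defined in this file; it assumes the `llama_stack.apis.inference` package re-exports them alongside the existing inference models:

```python
# Minimal sketch of the response models introduced by this change.
# Assumes llama_stack.apis.inference re-exports RerankData / RerankResponse.
from llama_stack.apis.inference import RerankData, RerankResponse

response = RerankResponse(
    data=[
        RerankData(index=2, relevance_score=0.91),
        RerankData(index=0, relevance_score=0.47),
    ]
)
# data is sorted by relevance_score (descending); index points back into
# the original `items` list passed to rerank().
top = response.data[0]
print(top.index, top.relevance_score)
```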
+ """ + raise NotImplementedError("Reranking is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete + @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 92422ac1b..8d1b5d697 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel): timestamp: int value: float + unit: str @json_schema_type @@ -518,7 +519,7 @@ class Telemetry(Protocol): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index c8ffce034..b32b8b3ae 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="cli") class StackRun(Subcommand): diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index fa1fe632b..2ceb9e9be 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -80,7 +80,7 @@ def get_provider_dependencies( normal_deps = [] special_deps = [] for package in deps: - if "--no-deps" in package or "--index-url" in package: + if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]): special_deps.append(package) else: normal_deps.append(package) diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,41 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default_factory=list) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def validate_credentials_config(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("Cannot use wildcard origins with credentials enabled") + return self + + +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: + return None + + if cors_config is True: + # dev mode: allow localhost on any port + return CORSConfig( + allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + ) + + if isinstance(cors_config, CORSConfig): + return cors_config + + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +384,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request 
configuration", ) + cors: bool | CORSConfig | None = Field( + default=None, + description="CORS configuration for cross-origin requests. Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index dd1fc8a50..9e7a8006c 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_distro_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal ) self.pool_executor = ThreadPoolExecutor(max_workers=4) - self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data self.loop = asyncio.new_event_loop() - def initialize(self): - if in_notebook(): - import nest_asyncio - - nest_asyncio.apply() - if not self.skip_logger_removal: - self._remove_root_logger_handlers() - # use a new event loop to avoid interfering with the main event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(self.async_client.initialize()) + loop.run_until_complete(self.async_client.initialize()) finally: asyncio.set_event_loop(None) - def _remove_root_logger_handlers(self): + def initialize(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Deprecated method for backward compatibility. """ - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. 
+ + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) diff --git a/llama_stack/core/routers/datasets.py b/llama_stack/core/routers/datasets.py index d7984f729..2f1d5f78e 100644 --- a/llama_stack/core/routers/datasets.py +++ b/llama_stack/core/routers/datasets.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class DatasetIORouter(DatasetIO): diff --git a/llama_stack/core/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py index f7a17eecf..ffca81bf0 100644 --- a/llama_stack/core/routers/eval_scoring.py +++ b/llama_stack/core/routers/eval_scoring.py @@ -16,7 +16,7 @@ from llama_stack.apis.scoring import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ScoringRouter(Scoring): diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6a3f07247..4b66601bb 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="core::routers") class InferenceRouter(Inference): diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 738ecded3..9ba3327f1 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class SafetyRouter(Safety): diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 5a40bc0c5..fd606f33b 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -22,7 +22,7 @@ from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..786b0e391 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class VectorIORouter(VectorIO): diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee8040..c875dee5b 100644 --- 
a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..e523746d8 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..b129c9ec5 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -26,7 +26,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e00..b6141efa9 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..71e5bed63 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -19,7 +19,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..b1918d20a 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b3906..eeea406c1 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def 
parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc46997..00f71b4fe 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index e4fb4ff2b..c98d3bec0 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthenticationMiddleware: diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 73d5581c2..a8af6f75a 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -23,7 +23,7 @@ from llama_stack.core.datatypes import ( ) from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthResponse(BaseModel): diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 1cb850cde..693f224c3 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl -logger = get_logger(name=__name__, category="quota") +logger = get_logger(name=__name__, category="core::server") class QuotaMiddleware: diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 3d94b6e81..d6dfc3435 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -28,6 +28,7 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -40,6 +41,7 @@ from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -82,7 +84,7 @@ from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="core::server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None): config_contents = yaml.safe_load(fp) if 
isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) - logger = get_logger(name=__name__, category="server", config=logger_config) + logger = get_logger(name=__name__, category="core::server", config=logger_config) if args.env: for env_pair in args.env: try: @@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 87a3978c1..f734d0285 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -225,7 +225,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any: try: result = re.sub(pattern, get_env_var, config) - return _convert_string_to_proper_type(result) + # Only apply type conversion if substitution actually happened + if result != config: + return _convert_string_to_proper_type(result) + return result except EnvVarError as e: raise EnvVarError(e.var_name, e.path) from None diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4b60e1001..5f4abe9aa 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -16,7 +16,7 @@ from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -logger = get_logger(__name__, category="core") +logger = get_logger(__name__, category="core::registry") class DistributionRegistry(Protocol): diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py index 30cd71e15..182a571ee 100644 --- a/llama_stack/core/utils/config_resolution.py +++ b/llama_stack/core/utils/config_resolution.py @@ -10,7 +10,7 @@ from pathlib import Path from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="config_resolution") +logger = get_logger(name=__name__, category="core") DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index 0bf42e7ee..8e6c0bf67 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -34,7 +34,7 @@ distribution_spec: telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::torchtune-cpu eval: - provider_type: inline::meta-reference datasetio: diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 02a268462..7523df581 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -156,13 +156,10 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: torchtune-cpu + provider_type: inline::torchtune-cpu config: - checkpoint_format: huggingface - distributed_backend: null - 
device: cpu - dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output + checkpoint_format: meta eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/distributions/meta-reference-gpu/doc_template.md b/llama_stack/distributions/meta-reference-gpu/doc_template.md index ff45c3826..602d053c4 100644 --- a/llama_stack/distributions/meta-reference-gpu/doc_template.md +++ b/llama_stack/distributions/meta-reference-gpu/doc_template.md @@ -1,7 +1,7 @@ --- orphan: true --- -# Meta Reference Distribution +# Meta Reference GPU Distribution ```{toctree} :maxdepth: 2 @@ -29,7 +29,7 @@ The following environment variables can be configured: ## Prerequisite: Downloading Models -Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. +Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. ``` $ llama model list --downloaded diff --git a/llama_stack/distributions/starter-gpu/__init__.py b/llama_stack/distributions/starter-gpu/__init__.py new file mode 100644 index 000000000..e762f9b6e --- /dev/null +++ b/llama_stack/distributions/starter-gpu/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .starter_gpu import get_distribution_template # noqa: F401 diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml new file mode 100644 index 000000000..ff7c58e6f --- /dev/null +++ b/llama_stack/distributions/starter-gpu/build.yaml @@ -0,0 +1,59 @@ +version: 2 +distribution_spec: + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for GPU-enabled environments. 
+ providers: + inference: + - provider_type: remote::cerebras + - provider_type: remote::ollama + - provider_type: remote::vllm + - provider_type: remote::tgi + - provider_type: remote::fireworks + - provider_type: remote::together + - provider_type: remote::bedrock + - provider_type: remote::nvidia + - provider_type: remote::openai + - provider_type: remote::anthropic + - provider_type: remote::gemini + - provider_type: remote::vertexai + - provider_type: remote::groq + - provider_type: remote::sambanova + - provider_type: inline::sentence-transformers + vector_io: + - provider_type: inline::faiss + - provider_type: inline::sqlite-vec + - provider_type: inline::milvus + - provider_type: remote::chromadb + - provider_type: remote::pgvector + files: + - provider_type: inline::localfs + safety: + - provider_type: inline::llama-guard + - provider_type: inline::code-scanner + agents: + - provider_type: inline::meta-reference + telemetry: + - provider_type: inline::meta-reference + post_training: + - provider_type: inline::huggingface-gpu + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference +image_type: venv +additional_pip_packages: +- aiosqlite +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml new file mode 100644 index 000000000..8aed61519 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -0,0 +1,241 @@ +version: 2 +image_name: starter-gpu +apis: +- agents +- batches +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ${env.CEREBRAS_API_KEY:+cerebras} + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY:=} + - provider_id: ${env.OLLAMA_URL:+ollama} + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + - provider_id: ${env.VLLM_URL:+vllm} + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: ${env.TGI_URL:+tgi} + provider_type: remote::tgi + config: + url: ${env.TGI_URL:=} + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:=} + - provider_id: together + provider_type: remote::together + config: + url: https://api.together.xyz/v1 + api_key: ${env.TOGETHER_API_KEY:=} + - provider_id: bedrock + provider_type: remote::bedrock + - provider_id: ${env.NVIDIA_API_KEY:+nvidia} + provider_type: remote::nvidia + config: + url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com} + api_key: ${env.NVIDIA_API_KEY:=} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True} + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} + - provider_id: anthropic + provider_type: remote::anthropic + 
config: + api_key: ${env.ANTHROPIC_API_KEY:=} + - provider_id: gemini + provider_type: remote::gemini + config: + api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} + - provider_id: groq + provider_type: remote::groq + config: + url: https://api.groq.com + api_key: ${env.GROQ_API_KEY:=} + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY:=} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + - provider_id: ${env.MILVUS_URL:+milvus} + provider_type: inline::milvus + config: + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + - provider_id: ${env.CHROMADB_URL:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + - provider_id: ${env.PGVECTOR_DB:+pgvector} + provider_type: remote::pgvector + config: + host: ${env.PGVECTOR_HOST:=localhost} + port: ${env.PGVECTOR_PORT:=5432} + db: ${env.PGVECTOR_DB:=} + user: ${env.PGVECTOR_USER:=} + password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + - provider_id: code-scanner + provider_type: inline::code-scanner + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + post_training: + - provider_id: huggingface-gpu + provider_type: inline::huggingface-gpu + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu + dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + 
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db +models: [] +shields: +- shield_id: llama-guard + provider_id: ${env.SAFETY_MODEL:+llama-guard} + provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py new file mode 100644 index 000000000..245334749 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/starter_gpu.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.distributions.template import BuildProvider, DistributionTemplate + +from ..starter.starter import get_distribution_template as get_starter_distribution_template + + +def get_distribution_template() -> DistributionTemplate: + template = get_starter_distribution_template() + name = "starter-gpu" + template.name = name + template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." 
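The run.yaml above leans heavily on the stack's environment-variable substitution syntax, and the `replace_env_vars` tweak earlier in this patch now skips type coercion whenever a value contains no `${env...}` reference at all. Below is a minimal sketch of the two operators' semantics, simplified for illustration only (the regex and helper name are not the actual implementation in `llama_stack.core.stack`):

```python
import os
import re

# Illustrative pattern only -- the real parser supports more forms.
_ENV_REF = re.compile(r"\$\{env\.(?P<name>\w+):(?P<op>[=+])(?P<arg>[^}]*)\}")


def substitute(raw: str) -> str:
    """Sketch of `${env.VAR:=default}` / `${env.VAR:+value}` as used in the run.yaml files above."""

    def repl(match: re.Match) -> str:
        value = os.environ.get(match.group("name"), "")
        if match.group("op") == "=":  # ${env.VAR:=default} -> VAR, or the default when unset/empty
            return value or match.group("arg")
        return match.group("arg") if value else ""  # ${env.VAR:+value} -> value only if VAR is set

    return _ENV_REF.sub(repl, raw)


print(substitute("${env.OLLAMA_URL:=http://localhost:11434}"))  # falls back to the default
print(substitute("${env.OLLAMA_URL:+ollama}"))                  # empty string when OLLAMA_URL is unset
```

With `OLLAMA_URL` unset, `${env.OLLAMA_URL:+ollama}` collapses to an empty provider_id, which is how the optional providers above are switched off; and because unsubstituted literals are now returned untouched, a plain config string that merely looks like a boolean or number is no longer coerced.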
+ + template.providers["post_training"] = [ + BuildProvider(provider_type="inline::huggingface-gpu"), + ] + return template diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index 2ad12a165..e84e528da 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -1,6 +1,7 @@ version: 2 distribution_spec: - description: Quick start template for running Llama Stack with several popular providers + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for CPU-only environments. providers: inference: - provider_type: remote::cerebras @@ -34,7 +35,7 @@ distribution_spec: telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::torchtune-cpu eval: - provider_type: inline::meta-reference datasetio: diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 7ac4dc6b9..a3962b8aa 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -156,13 +156,10 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: torchtune-cpu + provider_type: inline::torchtune-cpu config: - checkpoint_format: huggingface - distributed_backend: null - device: cpu - dpo_output_dir: ~/.llama/distributions/starter/dpo_output + checkpoint_format: meta eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index cad3d72d9..a4bbc6371 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate: ], "agents": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")], - "post_training": [BuildProvider(provider_type="inline::huggingface")], + "post_training": [BuildProvider(provider_type="inline::torchtune-cpu")], "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ BuildProvider(provider_type="remote::huggingface"), @@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Quick start template for running Llama Stack with several popular providers", + description="Quick start template for running Llama Stack with several popular providers. 
This distribution is intended for CPU-only environments.", container_image=None, template_path=None, providers=providers, diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 096156a5f..7b501eb0e 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple MP_SCALE = 8 -logger = get_logger(name=__name__, category="models") +logger = get_logger(name=__name__, category="models::llama") def reduce_from_tensor_model_parallel_region(input_): diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..d0e3e7671 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -11,7 +11,7 @@ from llama_stack.log import get_logger from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="models::llama") BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})</function>") diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 8220a9040..7557a8a64 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = get_logger(name=__name__, category="models") +log = get_logger(name=__name__, category="models::llama") def swiglu_wrapper_no_reduce( diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index 7fab2d3a6..0a205601f 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -9,7 +9,7 @@ import collections from llama_stack.log import get_logger -log = get_logger(name=__name__, category="llama") +log = get_logger(name=__name__, category="models::llama") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..fde38515b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class ChatAgent(ShieldRunnerMixin): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 5794ad2c0..8bdde86b0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig from .persistence import AgentInfo from .responses.openai_responses import OpenAIResponsesImpl -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class MetaReferenceAgentsImpl(Agents): diff
--git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index c19051f86..3b7b4729c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = get_logger(name=__name__, category="agents") +log = get_logger(name=__name__, category="agents::meta_reference") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index e528a4005..c632e61aa 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -41,7 +41,7 @@ from .utils import ( convert_response_text_to_chat_response_format, ) -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="openai::responses") class OpenAIResponsePreviousResponseWithInputItems(BaseModel): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 0879e978a..3e69fa5cd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -47,7 +47,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ChatCompletionResult from .utils import convert_chat_choice_to_response_message, is_function_tool_call -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class StreamingResponseOrchestrator: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 5b98b4f51..b028c018b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -38,7 +38,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ToolExecutionResult -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class ToolExecutor: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 486ac9351..7aaeb4cd5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages( """ messages: list[OpenAIMessageParam] = [] if isinstance(input, list): + # extract all OpenAIResponseInputFunctionToolCallOutput items + # so their corresponding OpenAIToolMessageParam instances can + # be added immediately following the corresponding + # OpenAIAssistantMessageParam + tool_call_results = {} for input_item in input: if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - 
content=input_item.output, - tool_call_id=input_item.call_id, - ) + tool_call_results[input_item.call_id] = OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, ) + + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + # skip as these have been extracted and inserted in order + pass elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): tool_call = OpenAIChatCompletionToolCall( index=0, @@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages( ), ) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + if input_item.call_id in tool_call_results: + messages.append(tool_call_results[input_item.call_id]) + del tool_call_results[input_item.call_id] elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall): tool_call = OpenAIChatCompletionToolCall( index=0, @@ -146,6 +157,10 @@ async def convert_response_input_to_chat_messages( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" ) messages.append(message_type(content=content)) + if len(tool_call_results): + raise ValueError( + f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" + ) else: messages.append(OpenAIUserMessageParam(content=input)) return messages diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index b8a5d8a95..8f3ecf5c9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = get_logger(name=__name__, category="agents") +log = get_logger(name=__name__, category="agents::meta_reference") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 1ff554e70..26f0ad15a 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import asyncio +import hashlib import itertools import json import time @@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches): endpoint: str, completion_window: Literal["24h"], metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, ) -> BatchObject: """ Create a new batch for processing multiple API requests. - Error handling by levels - - 0. Input param handling, results in 40x errors before processing, e.g. - - Wrong completion_window - - Invalid metadata types - - Unknown endpoint - -> no batch created - 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. - - input_file_id missing - - invalid json in file - - missing custom_id, method, url, body - - invalid model - - streaming - -> batch created, validation sends to failed status - 2. Processing errors, result in error_file_id entries, e.g. - - Any error returned from inference endpoint - -> batch created, goes to completed status + This implementation provides optional idempotency: when an idempotency key + (idempotency_key) is provided, a deterministic ID is generated based on the input + parameters. 
If a batch with the same parameters already exists, it will be + returned instead of creating a duplicate. Without an idempotency key, + each request creates a new batch with a unique ID. + + Args: + input_file_id: The ID of an uploaded file containing requests for the batch. + endpoint: The endpoint to be used for all requests in the batch. + completion_window: The time window within which the batch should be processed. + metadata: Optional metadata for the batch. + idempotency_key: Optional idempotency key for enabling idempotent behavior. + + Returns: + The created or existing batch object. """ + # Error handling by levels - + # 0. Input param handling, results in 40x errors before processing, e.g. + # - Wrong completion_window + # - Invalid metadata types + # - Unknown endpoint + # -> no batch created + # 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + # - input_file_id missing + # - invalid json in file + # - missing custom_id, method, url, body + # - invalid model + # - streaming + # -> batch created, validation sends to failed status + # 2. Processing errors, result in error_file_id entries, e.g. + # - Any error returned from inference endpoint + # -> batch created, goes to completed status + # TODO: set expiration time for garbage collection if endpoint not in ["/v1/chat/completions"]: @@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches): ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" + + # For idempotent requests, use the idempotency key for the batch ID + # This ensures the same key always maps to the same batch ID, + # allowing us to detect parameter conflicts + if idempotency_key is not None: + hash_input = idempotency_key.encode("utf-8") + hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] + batch_id = f"batch_{hash_digest}" + + try: + existing_batch = await self.retrieve_batch(batch_id) + + if ( + existing_batch.input_file_id != input_file_id + or existing_batch.endpoint != endpoint + or existing_batch.completion_window != completion_window + or existing_batch.metadata != metadata + ): + raise ConflictError( + f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + "Either use a new idempotency key or ensure all parameters match the original request." 
+ ) + + logger.info(f"Returning existing batch with ID: {batch_id}") + return existing_batch + except ResourceNotFoundError: + # Batch doesn't exist, continue with creation + pass + current_time = int(time.time()) batch = BatchObject( @@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches): ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + logger.info(f"Created new batch with ID: {batch_id}") if self.process_batches: task = asyncio.create_task(self._process_batch(batch_id)) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 600a5bd37..34665b63e 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( CompletionResponse, InferenceProvider, - InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl( tool_config: ToolConfig | None = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Sentence Transformers") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 30710ec2a..9224c3792 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import datetime import threading from typing import Any @@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: - raise NotImplementedError("Querying metrics is not implemented") + """Query metrics from the telemetry store. 
+ + Args: + metric_name: The name of the metric to query (e.g., "prompt_tokens") + start_time: Start time as Unix timestamp + end_time: End time as Unix timestamp (defaults to now if None) + granularity: Time granularity for aggregation + query_type: Type of query (RANGE or INSTANT) + label_matchers: Label filters to apply + + Returns: + QueryMetricsResponse with metric time series data + """ + # Convert timestamps to datetime objects + start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC) + end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None + + # Use SQLite trace store if available + if hasattr(self, "trace_store") and self.trace_store: + return await self.trace_store.query_metrics( + metric_name=metric_name, + start_time=start_dt, + end_time=end_dt, + granularity=granularity, + query_type=query_type, + label_matchers=label_matchers, + ) + else: + raise ValueError( + f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks" + ) def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index e894debaf..ebe90310c 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -5,9 +5,11 @@ # the root directory of this source tree. from llama_stack.providers.datatypes import ( + AdapterSpec, Api, InlineProviderSpec, ProviderSpec, + remote_provider_spec, ) from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), + remote_provider_spec( + api=Api.files, + adapter=AdapterSpec( + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", + ), + ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 1801cdcad..82b771a28 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]: InlineProviderSpec( api=Api.inference, provider_type="inline::sentence-transformers", + # CrossEncoder depends on torchao.quantization pip_packages=[ - "torch torchvision --index-url https://download.pytorch.org/whl/cpu", + "torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu", "sentence-transformers --no-deps", ], module="llama_stack.providers.inline.inference.sentence_transformers", diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index ffd64ef7c..67238e3fc 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -5,27 +5,50 @@ # the root directory of this source tree. 
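Stepping back to the reference batches provider earlier in this patch: the new idempotency support hashes the caller-supplied key into a deterministic batch ID, so a retry with the same key returns the existing batch, while reusing a key with different parameters raises `ConflictError`. A minimal sketch of the ID derivation it uses (the key values here are made up):

```python
import hashlib


def derive_batch_id(idempotency_key: str) -> str:
    # Same derivation as create_batch above: sha256 of the key, truncated to 24 hex chars.
    return f"batch_{hashlib.sha256(idempotency_key.encode('utf-8')).hexdigest()[:24]}"


# Retrying with the same key targets the same batch ID; a new key gets a fresh ID.
assert derive_batch_id("nightly-eval-2025-01-01") == derive_batch_id("nightly-eval-2025-01-01")
assert derive_batch_id("nightly-eval-2025-01-01") != derive_batch_id("nightly-eval-2025-01-02")
```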
+from typing import cast + from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +# We provide two versions of these providers so that distributions can package the appropriate version of torch. +# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. +torchtune_def = dict( + api=Api.post_training, + pip_packages=["numpy"], + module="llama_stack.providers.inline.post_training.torchtune", + config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", +) + def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( - api=Api.post_training, - provider_type="inline::torchtune", - pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"], - module="llama_stack.providers.inline.post_training.torchtune", - config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", + **{ # type: ignore + **torchtune_def, + "provider_type": "inline::torchtune-cpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + + ["torch torchtune>=0.5.0 torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu"] + ), + }, + ), + InlineProviderSpec( + **{ # type: ignore + **torchtune_def, + "provider_type": "inline::torchtune-gpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"] + ), + }, ), InlineProviderSpec( api=Api.post_training, - provider_type="inline::huggingface", - pip_packages=["torch", "trl", "transformers", "peft", "datasets"], + provider_type="inline::huggingface-gpu", + pip_packages=["trl", "transformers", "peft", "datasets", "torch"], module="llama_stack.providers.inline.post_training.huggingface", config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", api_dependencies=[ diff --git a/llama_stack/providers/remote/files/s3/README.md b/llama_stack/providers/remote/files/s3/README.md new file mode 100644 index 000000000..0f33122c7 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/README.md @@ -0,0 +1,237 @@ +# S3 Files Provider + +A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence. 
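Because the provider exposes the standard Files endpoints, any OpenAI-compatible client can exercise it. The snippet below is a minimal sketch only: it assumes a Llama Stack server running locally on port 8321 with the OpenAI-compatible route prefix `/v1/openai/v1`, a placeholder API key, and an example file name.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload a file; its bytes land in the configured S3 bucket, its metadata in the SQL store.
uploaded = client.files.create(file=open("report.pdf", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

# List files and download the content back.
for f in client.files.list(purpose="assistants").data:
    print(f.id, f.filename)
content = client.files.content(uploaded.id).read()
```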
+ +## Features + +- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage +- **Metadata Management**: Uses SQL database for efficient file metadata queries +- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints +- **Flexible Authentication**: Support for IAM roles and access keys +- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services + +## Configuration + +### Basic Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Advanced Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + aws_access_key_id: YOUR_ACCESS_KEY + aws_secret_access_key: YOUR_SECRET_KEY + endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Environment Variables + +The configuration supports environment variable substitution: + +```yaml +config: + bucket_name: "${env.S3_BUCKET_NAME}" + region: "${env.AWS_REGION:=us-east-1}" + aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}" + aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}" + endpoint_url: "${env.S3_ENDPOINT_URL:=}" +``` + +Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique. + +## Authentication + +### IAM Roles (Recommended) + +For production deployments, use IAM roles: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + # No credentials needed - will use IAM role +``` + +### Access Keys + +For development or specific use cases: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + aws_access_key_id: AKIAIOSFODNN7EXAMPLE + aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +``` + +## S3 Bucket Setup + +### Required Permissions + +The S3 provider requires the following permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Automatic Bucket Creation + +By default, the S3 provider expects the bucket to already exist. 
If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration: + +```yaml +config: + bucket_name: my-bucket + auto_create_bucket: true # Will create bucket if it doesn't exist + region: us-east-1 +``` + +**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket", + "s3:CreateBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Bucket Policy (Optional) + +For additional security, you can add a bucket policy: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "LlamaStackAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject" + ], + "Resource": "arn:aws:s3:::your-bucket-name/*" + }, + { + "Sid": "LlamaStackBucketAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:ListBucket" + ], + "Resource": "arn:aws:s3:::your-bucket-name" + } + ] +} +``` + +## Features + +### Metadata Persistence + +File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes: + +- File ID +- Original filename +- Purpose (assistants, batch, etc.) +- File size in bytes +- Created and expiration timestamps + +### TTL and Cleanup + +Files currently have a fixed long expiration time (100 years). + +## Development and Testing + +### Using MinIO + +For self-hosted S3-compatible storage: + +```yaml +config: + bucket_name: test-bucket + region: us-east-1 + endpoint_url: http://localhost:9000 + aws_access_key_id: minioadmin + aws_secret_access_key: minioadmin +``` + +## Monitoring and Logging + +The provider logs important operations and errors. For production deployments, consider: + +- CloudWatch monitoring for S3 operations +- Custom metrics for file upload/download rates +- Error rate monitoring +- Performance metrics tracking + +## Error Handling + +The provider handles various error scenarios: + +- S3 connectivity issues +- Bucket access permissions +- File not found errors +- Metadata consistency checks + +## Known Limitations + +- Fixed long TTL (100 years) instead of configurable expiration +- No server-side encryption enabled by default +- No support for AWS session tokens +- No S3 key prefix organization support +- No multipart upload support (all files uploaded as single objects) diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py new file mode 100644 index 000000000..3f5dfc88a --- /dev/null +++ b/llama_stack/providers/remote/files/s3/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.core.datatypes import Api + +from .config import S3FilesImplConfig + + +async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): + from .files import S3FilesImpl + + # TODO: authorization policies and user separation + impl = S3FilesImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/files/s3/config.py b/llama_stack/providers/remote/files/s3/config.py new file mode 100644 index 000000000..da20d8668 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/config.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig + + +class S3FilesImplConfig(BaseModel): + """Configuration for S3-based files provider.""" + + bucket_name: str = Field(description="S3 bucket name to store files") + region: str = Field(default="us-east-1", description="AWS region where the bucket is located") + aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)") + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret access key (optional if using IAM roles)" + ) + endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)") + auto_create_bucket: bool = Field( + default=False, description="Automatically create the S3 bucket if it doesn't exist" + ) + metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique + "region": "${env.AWS_REGION:=us-east-1}", + "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}", + "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}", + "endpoint_url": "${env.S3_ENDPOINT_URL:=}", + "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}", + "metadata_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="s3_files_metadata.db", + ), + } diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py new file mode 100644 index 000000000..52e0cbbf4 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/files.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
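For quick local testing of this provider against MinIO (matching the README's development settings), the config class above can be constructed directly. A sketch that assumes `SqliteSqlStoreConfig(db_path=...)` for the metadata store:

```python
from llama_stack.providers.remote.files.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

config = S3FilesImplConfig(
    bucket_name="test-bucket",
    region="us-east-1",
    endpoint_url="http://localhost:9000",   # MinIO
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    auto_create_bucket=True,
    metadata_store=SqliteSqlStoreConfig(db_path="./s3_files_metadata.db"),
)

# The provider is then built and initialized as in get_adapter_impl above:
#   impl = S3FilesImpl(config)
#   await impl.initialize()
```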
+ +import time +import uuid +from typing import Annotated + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError +from fastapi import File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl + +from .config import S3FilesImplConfig + +# TODO: provider data for S3 credentials + + +def _create_s3_client(config: S3FilesImplConfig) -> boto3.client: + try: + s3_config = { + "region_name": config.region, + } + + # endpoint URL if specified (for MinIO, LocalStack, etc.) + if config.endpoint_url: + s3_config["endpoint_url"] = config.endpoint_url + + if config.aws_access_key_id and config.aws_secret_access_key: + s3_config.update( + { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + } + ) + + return boto3.client("s3", **s3_config) + + except (BotoCoreError, NoCredentialsError) as e: + raise RuntimeError(f"Failed to initialize S3 client: {e}") from e + + +async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None: + try: + client.head_bucket(Bucket=config.bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "404": + if not config.auto_create_bucket: + raise RuntimeError( + f"S3 bucket '{config.bucket_name}' does not exist. " + f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration." 
+ ) from e + try: + # For us-east-1, we can't specify LocationConstraint + if config.region == "us-east-1": + client.create_bucket(Bucket=config.bucket_name) + else: + client.create_bucket( + Bucket=config.bucket_name, + CreateBucketConfiguration={"LocationConstraint": config.region}, + ) + except ClientError as create_error: + raise RuntimeError( + f"Failed to create S3 bucket '{config.bucket_name}': {create_error}" + ) from create_error + elif error_code == "403": + raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e + else: + raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e + + +class S3FilesImpl(Files): + """S3-based implementation of the Files API.""" + + # TODO: implement expiration, for now a silly offset + _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60 + + def __init__(self, config: S3FilesImplConfig) -> None: + self._config = config + self._client: boto3.client | None = None + self._sql_store: SqlStore | None = None + + async def initialize(self) -> None: + self._client = _create_s3_client(self._config) + await _create_bucket_if_not_exists(self._client, self._config) + + self._sql_store = sqlstore_impl(self._config.metadata_store) + await self._sql_store.create_table( + "openai_files", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "filename": ColumnType.STRING, + "purpose": ColumnType.STRING, + "bytes": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + "expires_at": ColumnType.INTEGER, + # TODO: add s3_etag field for integrity checking + }, + ) + + async def shutdown(self) -> None: + pass + + @property + def client(self) -> boto3.client: + assert self._client is not None, "Provider not initialized" + return self._client + + @property + def sql_store(self) -> SqlStore: + assert self._sql_store is not None, "Provider not initialized" + return self._sql_store + + async def openai_upload_file( + self, + file: Annotated[UploadFile, File()], + purpose: Annotated[OpenAIFilePurpose, Form()], + ) -> OpenAIFileObject: + file_id = f"file-{uuid.uuid4().hex}" + + filename = getattr(file, "filename", None) or "uploaded_file" + + created_at = int(time.time()) + expires_at = created_at + self._SILLY_EXPIRATION_OFFSET + content = await file.read() + file_size = len(content) + + await self.sql_store.insert( + "openai_files", + { + "id": file_id, + "filename": filename, + "purpose": purpose.value, + "bytes": file_size, + "created_at": created_at, + "expires_at": expires_at, + }, + ) + + try: + self.client.put_object( + Bucket=self._config.bucket_name, + Key=file_id, + Body=content, + # TODO: enable server-side encryption + ) + except ClientError as e: + await self.sql_store.delete("openai_files", where={"id": file_id}) + + raise RuntimeError(f"Failed to upload file to S3: {e}") from e + + return OpenAIFileObject( + id=file_id, + filename=filename, + purpose=purpose, + bytes=file_size, + created_at=created_at, + expires_at=expires_at, + ) + + async def openai_list_files( + self, + after: str | None = None, + limit: int | None = 10000, + order: Order | None = Order.desc, + purpose: OpenAIFilePurpose | None = None, + ) -> ListOpenAIFileResponse: + # this purely defensive. it should not happen because the router also default to Order.desc. 
+ if not order: + order = Order.desc + + where_conditions = {} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions if where_conditions else None, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [ + OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + for row in paginated_result.data + ] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + # empty string or None? spec says str, ref impl returns str | None, we go with spec + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + return OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + self.client.delete_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + except ClientError as e: + if e.response["Error"]["Code"] != "NoSuchKey": + raise RuntimeError(f"Failed to delete file from S3: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + response = self.client.get_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + # TODO: can we stream this instead of loading it into memory + content = response["Body"].read() + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + await self.sql_store.delete("openai_files", where={"id": file_id}) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from S3: {e}") from e + + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index bd86f7238..e907e8ec6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::fireworks") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git 
a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index cfcfcbf90..f2069b5e5 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::llama_openai_compat") class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index 35d26fd0b..d96b29fef 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -41,6 +41,11 @@ client.initialize() ### Create Completion +> Note on Completion API +> +> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does. + + ```python response = client.inference.completion( model_id="meta-llama/Llama-3.1-8B-Instruct", @@ -76,6 +81,73 @@ response = client.inference.chat_completion( print(f"Response: {response.completion_message.content}") ``` +### Tool Calling Example ### +```python +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + +tool_definition = ToolDefinition( + tool_name="get_weather", + description="Get current weather information for a location", + parameters={ + "location": ToolParamDefinition( + param_type="string", + description="The city and state, e.g. San Francisco, CA", + required=True, + ), + "unit": ToolParamDefinition( + param_type="string", + description="Temperature unit (celsius or fahrenheit)", + required=False, + default="celsius", + ), + }, +) + +tool_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], + tools=[tool_definition], +) + +print(f"Tool Response: {tool_response.completion_message.content}") +if tool_response.completion_message.tool_calls: + for tool_call in tool_response.completion_message.tool_calls: + print(f"Tool Called: {tool_call.tool_name}") + print(f"Arguments: {tool_call.arguments}") +``` + +### Structured Output Example +```python +from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType + +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "occupation": {"type": "string"}, + }, + "required": ["name", "age", "occupation"], +} + +response_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, json_schema=person_schema +) + +structured_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + { + "role": "user", + "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. 
", + } + ], + response_format=response_format, +) + +print(f"Structured Response: {structured_response.completion_message.content}") +``` + ### Create Embeddings > Note on OpenAI embeddings compatibility > diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7052cfb57..a5475bc92 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,7 @@ import warnings from collections.abc import AsyncIterator -from openai import NOT_GIVEN, APIConnectionError, BadRequestError +from openai import NOT_GIVEN, APIConnectionError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -57,7 +57,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::nvidia") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): @@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): } extra_body["input_type"] = task_type_options[task_type] - try: - response = await self.client.embeddings.create( - model=provider_model_id, - input=input, - extra_body=extra_body, - ) - except BadRequestError as e: - raise ValueError(f"Failed to get embeddings: {e}") from e - + response = await self.client.embeddings.create( + model=provider_model_id, + input=input, + extra_body=extra_body, + ) # # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # -> diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 790bbafd1..b8431e859 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -10,7 +10,7 @@ from llama_stack.log import get_logger from . 
import NVIDIAConfig -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::nvidia") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a93421536..fcaf5ee92 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::ollama") class OllamaInferenceAdapter( @@ -619,28 +619,6 @@ class OllamaInferenceAdapter( response.id = id return response - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Ollama") - async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 1c72fa0bc..0f73c9321 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::openai") # diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 9da961438..97c72d14c 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::tgi") def build_hf_repo_model_entries(): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a06e4173b..54c76607f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::together") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git 
a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index ac626874c..9e9a80ca5 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::vllm") def build_hf_repo_model_entries(): @@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): user=user, ) return await self.client.chat.completions.create(**params) # type: ignore - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Ollama") diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index 9a6c3b53c..162951ff3 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa from .config import NvidiaPostTrainingConfig -logger = get_logger(name=__name__, category="integration") +logger = get_logger(name=__name__, category="post_training::nvidia") def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1ca87ae3d..8855e02a4 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::bedrock") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 0d8d8ba7a..65f901da2 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -9,7 +9,7 @@ from typing import Any import requests from llama_stack.apis.inference import Message -from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel +from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_ from .config 
import NVIDIASafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::nvidia") class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): @@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation") + class NeMoGuardrails: """ diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 676ee7185..2beb5e0ea 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_ from .config import SambaNovaSafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::sambanova") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 0047e6055..a9ec644ef 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::chroma") ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 034ec331c..e07e8ff12 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = get_logger(name=__name__, category="vector_io") +logger = get_logger(name=__name__, category="vector_io::milvus") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 66c64b287..1c140e782 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -39,7 +39,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryA from .config import PGVectorVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::pgvector") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 8499ff997..0a0faa23a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as 
RemoteQdrantVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::qdrant") CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index ddf95317b..59b6bf124 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::weaviate") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 05886cdc8..65ba2854b 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con EMBEDDING_MODELS = {} -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index da2e634f6..9bd43e4c9 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class LiteLLMOpenAIMixin( @@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin( ) return await litellm.acompletion(**params) - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for OpenAI Compat") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") - async def check_model_availability(self, model: str) -> bool: """ Check if a specific model is available via LiteLLM for the current diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ddb3bda8c..44add8f9e 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): diff --git 
a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index eb32d2de9..55c2ac0ad 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class OpenAICompatCompletionChoiceDelta(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 72286dffb..f60deee6e 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -25,7 +25,7 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class OpenAIMixin(ABC): diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index bb9a91b97..a93326e41 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index af52f3708..bab87a4aa 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = get_logger(name=__name__, category="kvstore") +log = get_logger(name=__name__, category="providers::utils") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index 021e90774..56d6dbb48 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from ..api import KVStore from ..config import PostgresKVStoreConfig -log = get_logger(name=__name__, category="kvstore") +log = get_logger(name=__name__, category="providers::utils") class PostgresKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 0775b31d1..3acdcf293 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -logger = get_logger(name=__name__, category="memory") +logger = get_logger(name=__name__, category="providers::utils") # Constants for OpenAI 
vector stores CHUNK_MULTIPLIER = 5 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index b5d82432d..b74080384 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = get_logger(name=__name__, category="memory") +log = get_logger(name=__name__, category="providers::utils") class ChunkForDeletion(BaseModel): diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 65c3d2898..146591b2f 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="scheduler") +logger = get_logger(name=__name__, category="providers::utils") # TODO: revisit the list of possible statuses when defining a more coherent diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ccc835768..867ba2f55 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .sqlstore import SqlStoreType -logger = get_logger(name=__name__, category="authorized_sqlstore") +logger = get_logger(name=__name__, category="providers::utils") # Hardcoded copy of the default policy that our SQL filtering implements # WARNING: If default_policy() changes, this constant must be updated accordingly diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 7fa0cc755..f75c35314 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig -logger = get_logger(name=__name__, category="sqlstore") +logger = get_logger(name=__name__, category="providers::utils") TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py index 8dd6061a6..71480364c 100644 --- a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -5,12 +5,23 @@ # the root directory of this source tree. import json -from datetime import datetime +from datetime import UTC, datetime from typing import Protocol import aiosqlite -from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Trace +from llama_stack.apis.telemetry import ( + MetricDataPoint, + MetricLabel, + MetricLabelMatcher, + MetricQueryType, + MetricSeries, + QueryCondition, + QueryMetricsResponse, + Span, + SpanWithStatus, + Trace, +) class TraceStore(Protocol): @@ -29,11 +40,192 @@ class TraceStore(Protocol): max_depth: int | None = None, ) -> dict[str, SpanWithStatus]: ... 
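# A minimal sketch of how the query_metrics API introduced below might be exercised,
# assuming a local SQLite trace database and a hypothetical metric name
# ("completion_tokens"); only the method signature and the response/series/data-point
# shapes are taken from this change.

import asyncio
from datetime import UTC, datetime, timedelta

from llama_stack.apis.telemetry import MetricQueryType
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore


async def print_weekly_token_usage(conn_string: str) -> None:
    store = SQLiteTraceStore(conn_string=conn_string)
    response = await store.query_metrics(
        metric_name="completion_tokens",  # hypothetical metric; queried as "metric.completion_tokens"
        start_time=datetime.now(UTC) - timedelta(days=7),
        granularity="1d",  # one bucket per day (see _get_time_format_for_granularity)
        query_type=MetricQueryType.RANGE,
    )
    for series in response.data:
        for point in series.values:
            print(series.metric, point.timestamp, point.value, point.unit)


if __name__ == "__main__":
    asyncio.run(print_weekly_token_usage("trace_store.db"))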
+ async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = "1d", + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: ... + class SQLiteTraceStore(TraceStore): def __init__(self, conn_string: str): self.conn_string = conn_string + async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = None, + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: + if end_time is None: + end_time = datetime.now(UTC) + + # Build base query + if query_type == MetricQueryType.INSTANT: + query = """ + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + if granularity: + time_format = self._get_time_format_for_granularity(granularity) + query = f""" + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + strftime('{time_format}', se.timestamp) as bucket_start + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + query = """ + SELECT + se.name, + json_extract(se.attributes, '$.value') as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + se.timestamp + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + + params = [f"metric.{metric_name}", start_time.isoformat(), end_time.isoformat()] + + # Labels that will be attached to the MetricSeries (preserve matcher labels) + all_labels: list[MetricLabel] = [] + matcher_label_names = set() + if label_matchers: + for matcher in label_matchers: + json_path = f"$.{matcher.name}" + if matcher.operator == "=": + query += f" AND json_extract(se.attributes, '{json_path}') = ?" + params.append(matcher.value) + elif matcher.operator == "!=": + query += f" AND json_extract(se.attributes, '{json_path}') != ?" + params.append(matcher.value) + elif matcher.operator == "=~": + query += f" AND json_extract(se.attributes, '{json_path}') LIKE ?" + params.append(f"%{matcher.value}%") + elif matcher.operator == "!~": + query += f" AND json_extract(se.attributes, '{json_path}') NOT LIKE ?" 
+ params.append(f"%{matcher.value}%") + # Preserve filter context in output + all_labels.append(MetricLabel(name=matcher.name, value=str(matcher.value))) + matcher_label_names.add(matcher.name) + + # GROUP BY / ORDER BY logic + if query_type == MetricQueryType.RANGE and granularity: + group_time_format = self._get_time_format_for_granularity(granularity) + query += f" GROUP BY strftime('{group_time_format}', se.timestamp), json_extract(se.attributes, '$.unit')" + query += " ORDER BY bucket_start" + elif query_type == MetricQueryType.INSTANT: + query += " GROUP BY json_extract(se.attributes, '$.unit')" + else: + query += " ORDER BY se.timestamp" + + # Execute query + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + + if not rows: + return QueryMetricsResponse(data=[]) + + data_points = [] + # We want to add attribute labels, but only those not already present as matcher labels. + attr_label_names = set() + for row in rows: + # Parse JSON attributes safely, if there are no attributes (weird), just don't add the labels to the result. + try: + attributes = json.loads(row["attributes"] or "{}") + except (TypeError, json.JSONDecodeError): + attributes = {} + + value = row["value"] + unit = row["unit"] or "" + + # Add labels from attributes without duplicating matcher labels, if we don't do this, there will be a lot of duplicate label in the result. + for k, v in attributes.items(): + if k not in ["value", "unit"] and k not in matcher_label_names and k not in attr_label_names: + all_labels.append(MetricLabel(name=k, value=str(v))) + attr_label_names.add(k) + + # Determine timestamp + if query_type == MetricQueryType.RANGE and granularity: + try: + bucket_start_raw = row["bucket_start"] + except KeyError as e: + raise ValueError( + "DB did not have a bucket_start time in row when using granularity, this indicates improper formatting" + ) from e + # this value could also be there, but be NULL, I think. + if bucket_start_raw is None: + raise ValueError("bucket_start is None check time format and data") + bucket_start = datetime.fromisoformat(bucket_start_raw) + timestamp = int(bucket_start.timestamp()) + elif query_type == MetricQueryType.INSTANT: + timestamp = int(datetime.now(UTC).timestamp()) + else: + try: + timestamp_raw = row["timestamp"] + except KeyError as e: + raise ValueError( + "DB did not have a timestamp in row, this indicates improper formatting" + ) from e + # this value could also be there, but be NULL, I think. + if timestamp_raw is None: + raise ValueError("timestamp is None check time format and data") + timestamp_iso = datetime.fromisoformat(timestamp_raw) + timestamp = int(timestamp_iso.timestamp()) + + data_points.append( + MetricDataPoint( + timestamp=timestamp, + value=value, + unit=unit, + ) + ) + + metric_series = [MetricSeries(metric=metric_name, labels=all_labels, values=data_points)] + return QueryMetricsResponse(data=metric_series) + + def _get_time_format_for_granularity(self, granularity: str | None) -> str: + """Get the SQLite strftime format string for a given granularity. 
+ Args: + granularity: Granularity string (e.g., "1m", "5m", "1h", "1d") + Returns: + SQLite strftime format string for the granularity + """ + if granularity is None: + raise ValueError("granularity cannot be None for this method - use separate logic for no aggregation") + + if granularity.endswith("d"): + return "%Y-%m-%d 00:00:00" + elif granularity.endswith("h"): + return "%Y-%m-%d %H:00:00" + elif granularity.endswith("m"): + return "%Y-%m-%d %H:%M:00" + else: + return "%Y-%m-%d %H:%M:00" # Default to most granular which will give us the most timestamps. + async def query_traces( self, attribute_filters: list[QueryCondition] | None = None, diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py index 4a6958399..8fa5f5f2e 100644 --- a/llama_stack/testing/inference_recorder.py +++ b/llama_stack/testing/inference_recorder.py @@ -9,7 +9,6 @@ from __future__ import annotations # for forward references import hashlib import json import os -import sqlite3 from collections.abc import Generator from contextlib import contextmanager from enum import StrEnum @@ -125,28 +124,13 @@ class ResponseStorage: def __init__(self, test_dir: Path): self.test_dir = test_dir self.responses_dir = self.test_dir / "responses" - self.db_path = self.test_dir / "index.sqlite" self._ensure_directories() - self._init_database() def _ensure_directories(self): self.test_dir.mkdir(parents=True, exist_ok=True) self.responses_dir.mkdir(exist_ok=True) - def _init_database(self): - with sqlite3.connect(self.db_path) as conn: - conn.execute(""" - CREATE TABLE IF NOT EXISTS recordings ( - request_hash TEXT PRIMARY KEY, - response_file TEXT, - endpoint TEXT, - model TEXT, - timestamp TEXT, - is_streaming BOOLEAN - ) - """) - def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): """Store a request/response pair.""" # Generate unique response filename @@ -169,34 +153,9 @@ class ResponseStorage: f.write("\n") f.flush() - # Update SQLite index - with sqlite3.connect(self.db_path) as conn: - conn.execute( - """ - INSERT OR REPLACE INTO recordings - (request_hash, response_file, endpoint, model, timestamp, is_streaming) - VALUES (?, ?, ?, ?, datetime('now'), ?) 
- """, - ( - request_hash, - response_file, - request.get("endpoint", ""), - request.get("model", ""), - response.get("is_streaming", False), - ), - ) - def find_recording(self, request_hash: str) -> dict[str, Any] | None: """Find a recorded response by request hash.""" - with sqlite3.connect(self.db_path) as conn: - result = conn.execute( - "SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,) - ).fetchone() - - if not result: - return None - - response_file = result[0] + response_file = f"{request_hash[:12]}.json" response_path = self.responses_dir / response_file if not response_path.exists(): diff --git a/llama_stack/ui/app/chat-playground/chunk-processor.test.tsx b/llama_stack/ui/app/chat-playground/chunk-processor.test.tsx new file mode 100644 index 000000000..70e8b3afa --- /dev/null +++ b/llama_stack/ui/app/chat-playground/chunk-processor.test.tsx @@ -0,0 +1,610 @@ +import { describe, test, expect } from "@jest/globals"; + +// Extract the exact processChunk function implementation for testing +function createProcessChunk() { + return (chunk: unknown): { text: string | null; isToolCall: boolean } => { + const chunkObj = chunk as Record; + + // Helper function to check if content contains function call JSON + const containsToolCall = (content: string): boolean => { + return ( + content.includes('"type": "function"') || + content.includes('"name": "knowledge_search"') || + content.includes('"parameters":') || + !!content.match(/\{"type":\s*"function".*?\}/) + ); + }; + + // Check if this chunk contains a tool call (function call) + let isToolCall = false; + + // Check direct chunk content if it's a string + if (typeof chunk === "string") { + isToolCall = containsToolCall(chunk); + } + + // Check delta structures + if ( + chunkObj?.delta && + typeof chunkObj.delta === "object" && + chunkObj.delta !== null + ) { + const delta = chunkObj.delta as Record; + if ("tool_calls" in delta) { + isToolCall = true; + } + if (typeof delta.text === "string") { + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + } + + // Check event structures + if ( + chunkObj?.event && + typeof chunkObj.event === "object" && + chunkObj.event !== null + ) { + const event = chunkObj.event as Record; + + // Check event payload + if ( + event?.payload && + typeof event.payload === "object" && + event.payload !== null + ) { + const payload = event.payload as Record; + if (typeof payload.content === "string") { + if (containsToolCall(payload.content)) { + isToolCall = true; + } + } + + // Check payload delta + if ( + payload?.delta && + typeof payload.delta === "object" && + payload.delta !== null + ) { + const delta = payload.delta as Record; + if (typeof delta.text === "string") { + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + } + } + + // Check event delta + if ( + event?.delta && + typeof event.delta === "object" && + event.delta !== null + ) { + const delta = event.delta as Record; + if (typeof delta.text === "string") { + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + if (typeof delta.content === "string") { + if (containsToolCall(delta.content)) { + isToolCall = true; + } + } + } + } + + // if it's a tool call, skip it (don't display in chat) + if (isToolCall) { + return { text: null, isToolCall: true }; + } + + // Extract text content from various chunk formats + let text: string | null = null; + + // Helper function to extract clean text content, filtering out function calls + const extractCleanText = (content: string): string | 
null => { + if (containsToolCall(content)) { + try { + // Try to parse and extract non-function call parts + const jsonMatch = content.match( + /\{"type":\s*"function"[^}]*\}[^}]*\}/ + ); + if (jsonMatch) { + const jsonPart = jsonMatch[0]; + const parsedJson = JSON.parse(jsonPart); + + // If it's a function call, extract text after JSON + if (parsedJson.type === "function") { + const textAfterJson = content + .substring(content.indexOf(jsonPart) + jsonPart.length) + .trim(); + return textAfterJson || null; + } + } + // If we can't parse it properly, skip the whole thing + return null; + } catch { + return null; + } + } + return content; + }; + + // Try direct delta text + if ( + chunkObj?.delta && + typeof chunkObj.delta === "object" && + chunkObj.delta !== null + ) { + const delta = chunkObj.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + } + + // Try event structures + if ( + !text && + chunkObj?.event && + typeof chunkObj.event === "object" && + chunkObj.event !== null + ) { + const event = chunkObj.event as Record; + + // Try event payload content + if ( + event?.payload && + typeof event.payload === "object" && + event.payload !== null + ) { + const payload = event.payload as Record; + + // Try direct payload content + if (typeof payload.content === "string") { + text = extractCleanText(payload.content); + } + + // Try turn_complete event structure: payload.turn.output_message.content + if ( + !text && + payload?.turn && + typeof payload.turn === "object" && + payload.turn !== null + ) { + const turn = payload.turn as Record; + if ( + turn?.output_message && + typeof turn.output_message === "object" && + turn.output_message !== null + ) { + const outputMessage = turn.output_message as Record< + string, + unknown + >; + if (typeof outputMessage.content === "string") { + text = extractCleanText(outputMessage.content); + } + } + + // Fallback to model_response in steps if no output_message + if ( + !text && + turn?.steps && + Array.isArray(turn.steps) && + turn.steps.length > 0 + ) { + for (const step of turn.steps) { + if (step && typeof step === "object" && step !== null) { + const stepObj = step as Record; + if ( + stepObj?.model_response && + typeof stepObj.model_response === "object" && + stepObj.model_response !== null + ) { + const modelResponse = stepObj.model_response as Record< + string, + unknown + >; + if (typeof modelResponse.content === "string") { + text = extractCleanText(modelResponse.content); + break; + } + } + } + } + } + } + + // Try payload delta + if ( + !text && + payload?.delta && + typeof payload.delta === "object" && + payload.delta !== null + ) { + const delta = payload.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + } + } + + // Try event delta + if ( + !text && + event?.delta && + typeof event.delta === "object" && + event.delta !== null + ) { + const delta = event.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + if (!text && typeof delta.content === "string") { + text = extractCleanText(delta.content); + } + } + } + + // Try choices structure (ChatML format) + if ( + !text && + chunkObj?.choices && + Array.isArray(chunkObj.choices) && + chunkObj.choices.length > 0 + ) { + const choice = chunkObj.choices[0] as Record; + if ( + choice?.delta && + typeof choice.delta === "object" && + choice.delta !== null + ) { + const delta = choice.delta as Record; + if (typeof delta.content === "string") { + text = 
extractCleanText(delta.content); + } + } + } + + // Try direct string content + if (!text && typeof chunk === "string") { + text = extractCleanText(chunk); + } + + return { text, isToolCall: false }; + }; +} + +describe("Chunk Processor", () => { + const processChunk = createProcessChunk(); + + describe("Real Event Structures", () => { + test("handles turn_complete event with cancellation policy response", () => { + const chunk = { + event: { + payload: { + event_type: "turn_complete", + turn: { + turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db", + session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3", + input_messages: [ + { + role: "user", + content: "nice, what's the cancellation policy?", + context: null, + }, + ], + steps: [ + { + turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db", + step_id: "54074310-af42-414c-9ffe-fba5b2ead0ad", + started_at: "2025-08-27T18:15:25.870703Z", + completed_at: "2025-08-27T18:15:51.288993Z", + step_type: "inference", + model_response: { + role: "assistant", + content: + "According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.", + stop_reason: "end_of_turn", + tool_calls: [], + }, + }, + ], + output_message: { + role: "assistant", + content: + "According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.", + stop_reason: "end_of_turn", + tool_calls: [], + }, + output_attachments: [], + started_at: "2025-08-27T18:15:25.868548Z", + completed_at: "2025-08-27T18:15:51.289262Z", + }, + }, + }, + }; + + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toContain( + "According to the search results, the cancellation policy for Red Hat Summit is as follows:" + ); + expect(result.text).toContain("5 PM EDT on April 18, 2025"); + }); + + test("handles turn_complete event with address response", () => { + const chunk = { + event: { + payload: { + event_type: "turn_complete", + turn: { + turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0", + session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3", + input_messages: [ + { + role: "user", + content: "what's francisco's address", + context: null, + }, + ], + steps: [ + { + turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0", + step_id: "c13dd277-1acb-4419-8fbf-d5e2f45392ea", + started_at: "2025-08-27T18:14:52.558761Z", + completed_at: "2025-08-27T18:15:11.306032Z", + step_type: "inference", + model_response: { + role: "assistant", + content: + "Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920", + stop_reason: "end_of_turn", + tool_calls: [], + }, + }, + ], + output_message: { + role: "assistant", + content: + "Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920", + stop_reason: "end_of_turn", + tool_calls: [], + }, + output_attachments: [], + started_at: "2025-08-27T18:14:52.553707Z", + completed_at: "2025-08-27T18:15:11.306729Z", 
+ }, + }, + }, + }; + + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toContain("Francisco Arceo's address is:"); + expect(result.text).toContain("17 Primrose Ln"); + expect(result.text).toContain("Basking Ridge New Jersey 07920"); + }); + + test("handles turn_complete event with ticket cost response", () => { + const chunk = { + event: { + payload: { + event_type: "turn_complete", + turn: { + turn_id: "7ef244a3-efee-42ca-a9c8-942865251002", + session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3", + input_messages: [ + { + role: "user", + content: "what was the ticket cost for summit?", + context: null, + }, + ], + steps: [ + { + turn_id: "7ef244a3-efee-42ca-a9c8-942865251002", + step_id: "7651dda0-315a-472d-b1c1-3c2725f55bc5", + started_at: "2025-08-27T18:14:21.710611Z", + completed_at: "2025-08-27T18:14:39.706452Z", + step_type: "inference", + model_response: { + role: "assistant", + content: + "The ticket cost for the Red Hat Summit was $999.00 for a conference pass.", + stop_reason: "end_of_turn", + tool_calls: [], + }, + }, + ], + output_message: { + role: "assistant", + content: + "The ticket cost for the Red Hat Summit was $999.00 for a conference pass.", + stop_reason: "end_of_turn", + tool_calls: [], + }, + output_attachments: [], + started_at: "2025-08-27T18:14:21.705289Z", + completed_at: "2025-08-27T18:14:39.706752Z", + }, + }, + }, + }; + + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe( + "The ticket cost for the Red Hat Summit was $999.00 for a conference pass." + ); + }); + }); + + describe("Function Call Detection", () => { + test("detects function calls in direct string chunks", () => { + const chunk = + '{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}'; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(true); + expect(result.text).toBe(null); + }); + + test("detects function calls in event payload content", () => { + const chunk = { + event: { + payload: { + content: + '{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}', + }, + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(true); + expect(result.text).toBe(null); + }); + + test("detects tool_calls in delta structure", () => { + const chunk = { + delta: { + tool_calls: [{ function: { name: "knowledge_search" } }], + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(true); + expect(result.text).toBe(null); + }); + + test("detects function call in mixed content but skips it", () => { + const chunk = + '{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}} Based on the search results, here is your answer.'; + const result = processChunk(chunk); + // This is detected as a tool call and skipped entirely - the implementation prioritizes safety + expect(result.isToolCall).toBe(true); + expect(result.text).toBe(null); + }); + }); + + describe("Text Extraction", () => { + test("extracts text from direct string chunks", () => { + const chunk = "Hello, this is a normal response."; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe("Hello, this is a normal response."); + }); + + test("extracts text from delta structure", () => { + const chunk = { + delta: { + text: "Hello, this is a normal response.", + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + 
expect(result.text).toBe("Hello, this is a normal response."); + }); + + test("extracts text from choices structure", () => { + const chunk = { + choices: [ + { + delta: { + content: "Hello, this is a normal response.", + }, + }, + ], + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe("Hello, this is a normal response."); + }); + + test("prioritizes output_message over model_response in turn structure", () => { + const chunk = { + event: { + payload: { + turn: { + steps: [ + { + model_response: { + content: "Model response content.", + }, + }, + ], + output_message: { + content: "Final output message content.", + }, + }, + }, + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe("Final output message content."); + }); + + test("falls back to model_response when no output_message", () => { + const chunk = { + event: { + payload: { + turn: { + steps: [ + { + model_response: { + content: "This is from the model response.", + }, + }, + ], + }, + }, + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe("This is from the model response."); + }); + }); + + describe("Edge Cases", () => { + test("handles empty chunks", () => { + const result = processChunk(""); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe(""); + }); + + test("handles null chunks", () => { + const result = processChunk(null); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe(null); + }); + + test("handles undefined chunks", () => { + const result = processChunk(undefined); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe(null); + }); + + test("handles chunks with no text content", () => { + const chunk = { + event: { + metadata: { + timestamp: "2024-01-01", + }, + }, + }; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(false); + expect(result.text).toBe(null); + }); + + test("handles malformed JSON in function calls gracefully", () => { + const chunk = + '{"type": "function", "name": "knowledge_search"} incomplete json'; + const result = processChunk(chunk); + expect(result.isToolCall).toBe(true); + expect(result.text).toBe(null); + }); + }); +}); diff --git a/llama_stack/ui/app/chat-playground/page.test.tsx b/llama_stack/ui/app/chat-playground/page.test.tsx new file mode 100644 index 000000000..d9025e523 --- /dev/null +++ b/llama_stack/ui/app/chat-playground/page.test.tsx @@ -0,0 +1,790 @@ +import React from "react"; +import { + render, + screen, + fireEvent, + waitFor, + act, +} from "@testing-library/react"; +import "@testing-library/jest-dom"; +import ChatPlaygroundPage from "./page"; + +const mockClient = { + agents: { + list: jest.fn(), + create: jest.fn(), + retrieve: jest.fn(), + delete: jest.fn(), + session: { + list: jest.fn(), + create: jest.fn(), + delete: jest.fn(), + retrieve: jest.fn(), + }, + turn: { + create: jest.fn(), + }, + }, + models: { + list: jest.fn(), + }, + toolgroups: { + list: jest.fn(), + }, + vectorDBs: { + list: jest.fn(), + }, +}; + +jest.mock("@/hooks/use-auth-client", () => ({ + useAuthClient: jest.fn(() => mockClient), +})); + +jest.mock("@/components/chat-playground/chat", () => ({ + Chat: jest.fn( + ({ + className, + messages, + handleSubmit, + input, + handleInputChange, + isGenerating, + append, + suggestions, + }) => ( +
+      <div className={className}>
+        <div>{messages.length}</div>
+        <input
+          value={input}
+          onChange={handleInputChange}
+          disabled={isGenerating}
+        />
+        <button onClick={handleSubmit}>Send</button>
+        {suggestions?.map((suggestion: string, index: number) => (
+          <button
+            key={index}
+            onClick={() => append({ role: "user", content: suggestion })}
+          >
+            {suggestion}
+          </button>
+        ))}
+      </div>
+    )
+  ),
+}));
+
+jest.mock("@/components/chat-playground/conversations", () => ({
+  SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
+    <div>
+      {selectedAgentId && (
+        <>
+          <div>{selectedAgentId}</div>
+          <button onClick={onNewSession}>New Session</button>
+        </>
+      )}
+    </div>
+ )), + SessionUtils: { + saveCurrentSessionId: jest.fn(), + loadCurrentSessionId: jest.fn(), + loadCurrentAgentId: jest.fn(), + saveCurrentAgentId: jest.fn(), + clearCurrentSession: jest.fn(), + saveSessionData: jest.fn(), + loadSessionData: jest.fn(), + saveAgentConfig: jest.fn(), + loadAgentConfig: jest.fn(), + clearAgentCache: jest.fn(), + createDefaultSession: jest.fn(() => ({ + id: "test-session-123", + name: "Default Session", + messages: [], + selectedModel: "", + systemMessage: "You are a helpful assistant.", + agentId: "test-agent-123", + createdAt: Date.now(), + updatedAt: Date.now(), + })), + }, +})); + +const mockAgents = [ + { + agent_id: "agent_123", + agent_config: { + name: "Test Agent", + instructions: "You are a test assistant.", + }, + }, + { + agent_id: "agent_456", + agent_config: { + agent_name: "Another Agent", + instructions: "You are another assistant.", + }, + }, +]; + +const mockModels = [ + { + identifier: "test-model-1", + model_type: "llm", + }, + { + identifier: "test-model-2", + model_type: "llm", + }, +]; + +const mockToolgroups = [ + { + identifier: "builtin::rag", + provider_id: "test-provider", + type: "tool_group", + provider_resource_id: "test-resource", + }, +]; + +describe("ChatPlaygroundPage", () => { + beforeEach(() => { + jest.clearAllMocks(); + Element.prototype.scrollIntoView = jest.fn(); + mockClient.agents.list.mockResolvedValue({ data: mockAgents }); + mockClient.models.list.mockResolvedValue(mockModels); + mockClient.toolgroups.list.mockResolvedValue(mockToolgroups); + mockClient.agents.session.create.mockResolvedValue({ + session_id: "new-session-123", + }); + mockClient.agents.session.list.mockResolvedValue({ data: [] }); + mockClient.agents.session.retrieve.mockResolvedValue({ + session_id: "test-session", + session_name: "Test Session", + started_at: new Date().toISOString(), + turns: [], + }); + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: ["builtin::rag"], + instructions: "Test instructions", + model: "test-model", + }, + }); + mockClient.agents.delete.mockResolvedValue(undefined); + }); + + describe("Agent Selector Rendering", () => { + test("shows agent selector when agents are available", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByText("Agent Session:")).toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(2); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.getByText("Clear Chat")).toBeInTheDocument(); + }); + }); + + test("does not show agent selector when no agents are available", async () => { + mockClient.agents.list.mockResolvedValue({ data: [] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); + }); + }); + + test("does not show agent selector while loading", async () => { + mockClient.agents.list.mockImplementation(() => new Promise(() => {})); + + await act(async () => { + render(); + }); + + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); + }); + + test("shows 
agent options in selector", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Test Agent")).toHaveLength(2); + expect(screen.getByText("Another Agent")).toBeInTheDocument(); + }); + }); + + test("displays agent ID when no name is available", async () => { + const agentWithoutName = { + agent_id: "agent_789", + agent_config: { + instructions: "You are an agent without a name.", + }, + }; + + mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Agent agent_78") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2); + }); + }); + }); + + describe("Agent Creation Modal", () => { + test("opens agent creation modal when + New Agent is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument(); + expect(screen.getAllByText("Model")).toHaveLength(2); + expect(screen.getByText("System Instructions")).toBeInTheDocument(); + expect(screen.getByText("Tools (optional)")).toBeInTheDocument(); + }); + + test("closes modal when Cancel is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + const cancelButton = screen.getByText("Cancel"); + fireEvent.click(cancelButton); + + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + + test("creates agent when Create Agent is clicked", async () => { + mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" }); + mockClient.agents.list + .mockResolvedValueOnce({ data: mockAgents }) + .mockResolvedValueOnce({ + data: [ + ...mockAgents, + { agent_id: "new-agent-123", agent_config: { name: "New Agent" } }, + ], + }); + + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + await act(async () => { + fireEvent.click(newAgentButton); + }); + + await waitFor(() => { + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + }); + + const nameInput = screen.getByPlaceholderText("My Custom Agent"); + await act(async () => { + fireEvent.change(nameInput, { target: { value: "Test Agent Name" } }); + }); + + const instructionsTextarea = screen.getByDisplayValue( + "You are a helpful assistant." 
+ ); + await act(async () => { + fireEvent.change(instructionsTextarea, { + target: { value: "Custom instructions" }, + }); + }); + + await waitFor(() => { + const modalModelSelectors = screen + .getAllByRole("combobox") + .filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + expect(modalModelSelectors.length).toBeGreaterThan(0); + }); + + const modalModelSelectors = screen.getAllByRole("combobox").filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + + await act(async () => { + fireEvent.click(modalModelSelectors[0]); + }); + + await waitFor(() => { + const modelOptions = screen.getAllByText("test-model-1"); + expect(modelOptions.length).toBeGreaterThan(0); + }); + + const modelOptions = screen.getAllByText("test-model-1"); + const dropdownOption = modelOptions.find( + option => + option.closest('[role="option"]') || + option.id?.includes("radix") || + option.getAttribute("aria-selected") !== null + ); + + await act(async () => { + fireEvent.click( + dropdownOption || modelOptions[modelOptions.length - 1] + ); + }); + + await waitFor(() => { + const createButton = screen.getByText("Create Agent"); + expect(createButton).not.toBeDisabled(); + }); + + const createButton = screen.getByText("Create Agent"); + await act(async () => { + fireEvent.click(createButton); + }); + + await waitFor(() => { + expect(mockClient.agents.create).toHaveBeenCalledWith({ + agent_config: { + model: expect.any(String), + instructions: "Custom instructions", + name: "Test Agent Name", + enable_session_persistence: true, + }, + }); + }); + + await waitFor(() => { + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + }); + }); + + describe("Agent Selection", () => { + test("creates default session when agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_123", + { session_name: "Default Session" } + ); + }); + }); + + test("switches agent when different agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + const anotherAgentOption = screen.getByText("Another Agent"); + fireEvent.click(anotherAgentOption); + }); + + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_456", + { session_name: "Default Session" } + ); + }); + }); + + describe("Agent Deletion", () => { + test("shows delete button when multiple agents exist", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + }); + + test("shows delete button even when only one agent exists", async () => { + mockClient.agents.list.mockResolvedValue({ + data: [mockAgents[0]], + }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + }); + + test("deletes agent and switches to another when confirmed", async () => { + global.confirm = jest.fn(() => 
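// Reviewer note (illustrative sketch, not part of the patch): `mockRestore()` only restores
// an original implementation when the mock was created via `jest.spyOn`; after assigning
// `global.confirm = jest.fn()`, the later `mockRestore()` call is effectively a no-op.
// A spy-based variant for these deletion tests, assuming a jsdom test environment:
let confirmSpy: jest.SpyInstance;

beforeEach(() => {
  // auto-accept every confirm dialog raised by the component under test
  confirmSpy = jest.spyOn(window, "confirm").mockReturnValue(true);
});

afterEach(() => {
  // puts the real window.confirm back
  confirmSpy.mockRestore();
});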
true); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + mockClient.agents.delete.mockResolvedValue(undefined); + mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents }); + mockClient.agents.list.mockResolvedValueOnce({ + data: [mockAgents[1]], + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123"); + expect(global.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this agent? This action cannot be undone and will delete the agent and all its sessions." + ); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + + test("does not delete agent when cancelled", async () => { + global.confirm = jest.fn(() => false); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(global.confirm).toHaveBeenCalled(); + expect(mockClient.agents.delete).not.toHaveBeenCalled(); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + }); + + describe("Error Handling", () => { + test("handles agent loading errors gracefully", async () => { + mockClient.agents.list.mockRejectedValue( + new Error("Failed to load agents") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching agents:", + expect.any(Error) + ); + }); + + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + + consoleSpy.mockRestore(); + }); + + test("handles model loading errors gracefully", async () => { + mockClient.models.list.mockRejectedValue( + new Error("Failed to load models") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching models:", + expect.any(Error) + ); + }); + + consoleSpy.mockRestore(); + }); + }); + + describe("RAG File Upload", () => { + let mockFileReader: { + readAsDataURL: jest.Mock; + readAsText: jest.Mock; + result: string | null; + onload: (() => void) | null; + onerror: (() => void) | null; + }; + let mockRAGTool: { + insert: jest.Mock; + }; + + beforeEach(() => { + mockFileReader = { + readAsDataURL: jest.fn(), + readAsText: jest.fn(), + result: null, + onload: null, + onerror: null, + }; + global.FileReader = jest.fn(() => mockFileReader); + + mockRAGTool = { + insert: jest.fn().mockResolvedValue({}), + }; + mockClient.toolRuntime = { + ragTool: mockRAGTool, + }; + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + test("handles text file upload", async () => { + new File(["Hello, world!"], "test.txt", { + type: "text/plain", + }); + + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: [ + { + name: "builtin::rag/knowledge_search", + args: { vector_db_ids: ["test-vector-db"] }, + }, + ], + }, + }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTestId("chat-component")).toBeInTheDocument(); + }); + + const 
chatComponent = screen.getByTestId("chat-component"); + chatComponent.getAttribute("data-onragfileupload"); + + // this is a simplified test + expect(mockRAGTool.insert).not.toHaveBeenCalled(); + }); + + test("handles PDF file upload with FileReader", async () => { + new File([new ArrayBuffer(1000)], "test.pdf", { + type: "application/pdf", + }); + + const mockDataURL = "data:application/pdf;base64,JVBERi0xLjQK"; + mockFileReader.result = mockDataURL; + + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: [ + { + name: "builtin::rag/knowledge_search", + args: { vector_db_ids: ["test-vector-db"] }, + }, + ], + }, + }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTestId("chat-component")).toBeInTheDocument(); + }); + + expect(global.FileReader).toBeDefined(); + }); + + test("handles different file types correctly", () => { + const getContentType = (filename: string): string => { + const ext = filename.toLowerCase().split(".").pop(); + switch (ext) { + case "pdf": + return "application/pdf"; + case "txt": + return "text/plain"; + case "md": + return "text/markdown"; + case "html": + return "text/html"; + case "csv": + return "text/csv"; + case "json": + return "application/json"; + case "docx": + return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"; + case "doc": + return "application/msword"; + default: + return "application/octet-stream"; + } + }; + + expect(getContentType("test.pdf")).toBe("application/pdf"); + expect(getContentType("test.txt")).toBe("text/plain"); + expect(getContentType("test.md")).toBe("text/markdown"); + expect(getContentType("test.html")).toBe("text/html"); + expect(getContentType("test.csv")).toBe("text/csv"); + expect(getContentType("test.json")).toBe("application/json"); + expect(getContentType("test.docx")).toBe( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ); + expect(getContentType("test.doc")).toBe("application/msword"); + expect(getContentType("test.unknown")).toBe("application/octet-stream"); + }); + + test("determines text vs binary file types correctly", () => { + const isTextFile = (mimeType: string): boolean => { + return ( + mimeType.startsWith("text/") || + mimeType === "application/json" || + mimeType === "text/markdown" || + mimeType === "text/html" || + mimeType === "text/csv" + ); + }; + + expect(isTextFile("text/plain")).toBe(true); + expect(isTextFile("text/markdown")).toBe(true); + expect(isTextFile("text/html")).toBe(true); + expect(isTextFile("text/csv")).toBe(true); + expect(isTextFile("application/json")).toBe(true); + + expect(isTextFile("application/pdf")).toBe(false); + expect(isTextFile("application/msword")).toBe(false); + expect( + isTextFile( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ) + ).toBe(false); + expect(isTextFile("application/octet-stream")).toBe(false); + }); + + test("handles FileReader error gracefully", async () => { + const pdfFile = new File([new ArrayBuffer(1000)], "test.pdf", { + type: "application/pdf", + }); + + mockFileReader.onerror = jest.fn(); + const mockError = new Error("FileReader failed"); + + const fileReaderPromise = new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as string); + reader.onerror = () => reject(reader.error || mockError); + reader.readAsDataURL(pdfFile); + + setTimeout(() => { + reader.onerror?.(new ProgressEvent("error")); + }, 
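// Reviewer note (illustrative sketch): the PDF/binary upload path wraps FileReader in a
// Promise to obtain a base64 data URL, which is what the tests above and the
// handleRAGFileUpload implementation below both rely on. The pattern in isolation:
function readFileAsDataURL(file: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    // result is a data URL, e.g. "data:application/pdf;base64,JVBERi0xLjQK..."
    reader.onload = () => resolve(reader.result as string);
    reader.onerror = () => reject(reader.error ?? new Error("FileReader failed"));
    reader.readAsDataURL(file);
  });
}

// usage sketch:
// const dataUrl = await readFileAsDataURL(
//   new File([new ArrayBuffer(8)], "doc.pdf", { type: "application/pdf" })
// );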
0); + }); + + await expect(fileReaderPromise).rejects.toBeDefined(); + }); + + test("handles large file upload with FileReader approach", () => { + // create a large file + const largeFile = new File( + [new ArrayBuffer(10 * 1024 * 1024)], + "large.pdf", + { + type: "application/pdf", + } + ); + + expect(largeFile.size).toBe(10 * 1024 * 1024); // 10MB + + expect(global.FileReader).toBeDefined(); + + const reader = new FileReader(); + expect(reader.readAsDataURL).toBeDefined(); + }); + }); +}); diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index b8651aca0..0417f7083 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback, useRef } from "react"; import { flushSync } from "react-dom"; import { Button } from "@/components/ui/button"; import { @@ -10,14 +10,27 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Card } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Trash2 } from "lucide-react"; import { Chat } from "@/components/chat-playground/chat"; import { type Message } from "@/components/chat-playground/chat-message"; +import { VectorDBCreator } from "@/components/chat-playground/vector-db-creator"; import { useAuthClient } from "@/hooks/use-auth-client"; -import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; import type { Model } from "llama-stack-client/resources/models"; - +import type { TurnCreateParams } from "llama-stack-client/resources/agents/turn"; +import { + SessionUtils, + type ChatSession, +} from "@/components/chat-playground/conversations"; +import { + cleanMessageContent, + extractCleanText, +} from "@/lib/message-content-utils"; export default function ChatPlaygroundPage() { - const [messages, setMessages] = useState([]); + const [currentSession, setCurrentSession] = useState( + null + ); const [input, setInput] = useState(""); const [isGenerating, setIsGenerating] = useState(false); const [error, setError] = useState(null); @@ -25,20 +38,651 @@ export default function ChatPlaygroundPage() { const [selectedModel, setSelectedModel] = useState(""); const [modelsLoading, setModelsLoading] = useState(true); const [modelsError, setModelsError] = useState(null); + const [agents, setAgents] = useState< + Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }> + >([]); + const [selectedAgentConfig, setSelectedAgentConfig] = useState<{ + toolgroups?: Array< + string | { name: string; args: Record } + >; + } | null>(null); + const [selectedAgentId, setSelectedAgentId] = useState(""); + const [agentsLoading, setAgentsLoading] = useState(true); + const [showCreateAgent, setShowCreateAgent] = useState(false); + const [newAgentName, setNewAgentName] = useState(""); + const [newAgentInstructions, setNewAgentInstructions] = useState( + "You are a helpful assistant." 
+ ); + const [selectedToolgroups, setSelectedToolgroups] = useState([]); + const [availableToolgroups, setAvailableToolgroups] = useState< + Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }> + >([]); + const [showCreateVectorDB, setShowCreateVectorDB] = useState(false); + const [availableVectorDBs, setAvailableVectorDBs] = useState< + Array<{ + identifier: string; + vector_db_name?: string; + embedding_model: string; + }> + >([]); + const [uploadNotification, setUploadNotification] = useState<{ + show: boolean; + message: string; + type: "success" | "error" | "loading"; + }>({ show: false, message: "", type: "success" }); + const [selectedVectorDBs, setSelectedVectorDBs] = useState([]); const client = useAuthClient(); + const abortControllerRef = useRef(null); const isModelsLoading = modelsLoading ?? true; + const loadAgentConfig = useCallback( + async (agentId: string) => { + try { + // try to load from cache first + const cachedConfig = SessionUtils.loadAgentConfig(agentId); + if (cachedConfig) { + setSelectedAgentConfig({ + toolgroups: cachedConfig.toolgroups, + }); + return; + } + + const agentDetails = await client.agents.retrieve(agentId); + + // cache config + SessionUtils.saveAgentConfig(agentId, { + ...agentDetails.agent_config, + toolgroups: agentDetails.agent_config?.toolgroups, + }); + + setSelectedAgentConfig({ + toolgroups: agentDetails.agent_config?.toolgroups, + }); + } catch (error) { + console.error("Error loading agent config:", error); + setSelectedAgentConfig(null); + } + }, + [client] + ); + + const createDefaultSession = useCallback( + async (agentId: string) => { + try { + const response = await client.agents.session.create(agentId, { + session_name: "Default Session", + }); + + const defaultSession: ChatSession = { + id: response.session_id, + name: "Default Session", + messages: [], + selectedModel: selectedModel, // use current selected model + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(defaultSession); + SessionUtils.saveCurrentSessionId(defaultSession.id, agentId); + // cache entire session data + SessionUtils.saveSessionData(agentId, defaultSession); + } catch (error) { + console.error("Error creating default session:", error); + } + }, + [client, selectedModel] + ); + + const loadSessionMessages = useCallback( + async (agentId: string, sessionId: string): Promise => { + try { + const session = await client.agents.session.retrieve( + agentId, + sessionId + ); + + if (!session || !session.turns || !Array.isArray(session.turns)) { + return []; + } + + const messages: Message[] = []; + for (const turn of session.turns) { + if (turn.input_messages && Array.isArray(turn.input_messages)) { + for (const input of turn.input_messages) { + if (input.role === "user" && input.content) { + messages.push({ + id: `${turn.turn_id}-user-${messages.length}`, + role: "user", + content: + typeof input.content === "string" + ? 
input.content + : JSON.stringify(input.content), + createdAt: new Date(turn.started_at || Date.now()), + }); + } + } + } + + if (turn.output_message && turn.output_message.content) { + console.log("Raw message content:", turn.output_message.content); + console.log("Content type:", typeof turn.output_message.content); + + const cleanContent = cleanMessageContent( + turn.output_message.content + ); + + messages.push({ + id: `${turn.turn_id}-assistant-${messages.length}`, + role: "assistant", + content: cleanContent, + createdAt: new Date( + turn.completed_at || turn.started_at || Date.now() + ), + }); + } + } + + return messages; + } catch (error) { + console.error("Error loading session messages:", error); + return []; + } + }, + [client] + ); + + const loadAgentSessions = useCallback( + async (agentId: string) => { + try { + const response = await client.agents.session.list(agentId); + + if ( + response.data && + Array.isArray(response.data) && + response.data.length > 0 + ) { + // check for saved session ID for this agent + const savedSessionId = SessionUtils.loadCurrentSessionId(agentId); + // try to load cached agent session data first + if (savedSessionId) { + const cachedSession = SessionUtils.loadSessionData( + agentId, + savedSessionId + ); + if (cachedSession) { + setCurrentSession(cachedSession); + SessionUtils.saveCurrentSessionId(cachedSession.id, agentId); + return; + } + console.log("📡 Cache miss, fetching session from API..."); + } + + let sessionToLoad = response.data[0] as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "Default session to load (first in list):", + sessionToLoad.session_id + ); + + // try to find saved session id in available sessions + if (savedSessionId) { + const foundSession = response.data.find( + (s: { [key: string]: unknown }) => + (s as { session_id: string }).session_id === savedSessionId + ); + console.log("Found saved session in list:", foundSession); + if (foundSession) { + sessionToLoad = foundSession as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "✅ Restored previously selected session:", + savedSessionId + ); + } else { + console.log( + "❌ Previously selected session not found, using latest session" + ); + } + } else { + console.log("❌ No saved session ID found, using latest session"); + } + + const messages = await loadSessionMessages( + agentId, + sessionToLoad.session_id + ); + + const session: ChatSession = { + id: sessionToLoad.session_id, + name: sessionToLoad.session_name || "Session", + messages, + selectedModel: selectedModel || "", + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: sessionToLoad.started_at + ? 
new Date(sessionToLoad.started_at).getTime() + : Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(session); + console.log(`💾 Saving session ID for agent ${agentId}:`, session.id); + SessionUtils.saveCurrentSessionId(session.id, agentId); + // cache session data + SessionUtils.saveSessionData(agentId, session); + } else { + // no sessions, create a new one + await createDefaultSession(agentId); + } + } catch (error) { + console.error("Error loading agent sessions:", error); + // fallback to creating a new session + await createDefaultSession(agentId); + } + }, + [client, loadSessionMessages, createDefaultSession, selectedModel] + ); + + useEffect(() => { + const fetchAgents = async () => { + try { + setAgentsLoading(true); + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + if (agentList.data && agentList.data.length > 0) { + // check if there's a previously selected agent + const savedAgentId = SessionUtils.loadCurrentAgentId(); + + let agentToSelect = agentList.data[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + + // if we have a saved agent ID, find it in the available agents + if (savedAgentId) { + const foundAgent = agentList.data.find( + (a: { [key: string]: unknown }) => + (a as { agent_id: string }).agent_id === savedAgentId + ); + if (foundAgent) { + agentToSelect = foundAgent as typeof agentToSelect; + } else { + console.log("Previously slelected agent not found:"); + } + } + setSelectedAgentId(agentToSelect.agent_id); + SessionUtils.saveCurrentAgentId(agentToSelect.agent_id); + // load agent config immediately + await loadAgentConfig(agentToSelect.agent_id); + // Note: loadAgentSessions will be called after models are loaded + } + } catch (error) { + console.error("Error fetching agents:", error); + } finally { + setAgentsLoading(false); + } + }; + + fetchAgents(); + + const fetchToolgroups = async () => { + try { + const toolgroups = await client.toolgroups.list(); + + const toolGroupsArray = Array.isArray(toolgroups) + ? toolgroups + : toolgroups && + typeof toolgroups === "object" && + "data" in toolgroups && + Array.isArray((toolgroups as { data: unknown }).data) + ? ( + toolgroups as { + data: Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }>; + } + ).data + : []; + + if (toolGroupsArray && Array.isArray(toolGroupsArray)) { + setAvailableToolgroups(toolGroupsArray); + } else { + console.error("Invalid toolgroups data format:", toolgroups); + } + } catch (error) { + console.error("Error fetching toolgroups:", error); + if (error instanceof Error) { + console.error("Error details:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + } + } + }; + + fetchToolgroups(); + + const fetchVectorDBs = async () => { + try { + const vectorDBs = await client.vectorDBs.list(); + + const vectorDBsArray = Array.isArray(vectorDBs) ? 
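// Reviewer note (illustrative sketch, names are assumptions): SessionUtils is imported from
// components/chat-playground/conversations and its implementation is not part of this diff.
// Based purely on how it is called above, it behaves like a small per-agent persistence
// helper; a hypothetical localStorage-backed shape (storage keys invented for illustration):
const hypotheticalSessionUtils = {
  saveCurrentAgentId(agentId: string): void {
    localStorage.setItem("chat-playground:current-agent", agentId);
  },
  loadCurrentAgentId(): string | null {
    return localStorage.getItem("chat-playground:current-agent");
  },
  saveCurrentSessionId(sessionId: string, agentId: string): void {
    localStorage.setItem(`chat-playground:session:${agentId}`, sessionId);
  },
  loadCurrentSessionId(agentId: string): string | null {
    return localStorage.getItem(`chat-playground:session:${agentId}`);
  },
};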
vectorDBs : []; + + if (vectorDBsArray && Array.isArray(vectorDBsArray)) { + setAvailableVectorDBs(vectorDBsArray); + } else { + console.error("Invalid vector DBs data format:", vectorDBs); + } + } catch (error) { + console.error("Error fetching vector DBs:", error); + } + }; + + fetchVectorDBs(); + }, [client, loadAgentSessions, loadAgentConfig]); + + const createNewAgent = useCallback( + async ( + name: string, + instructions: string, + model: string, + toolgroups: string[] = [], + vectorDBs: string[] = [] + ) => { + try { + const processedToolgroups = toolgroups.map(toolgroup => { + if (toolgroup === "builtin::rag" && vectorDBs.length > 0) { + return { + name: "builtin::rag/knowledge_search", + args: { + vector_db_ids: vectorDBs, + }, + }; + } + return toolgroup; + }); + + const agentConfig = { + model, + instructions, + name: name || undefined, + enable_session_persistence: true, + toolgroups: + processedToolgroups.length > 0 ? processedToolgroups : undefined, + }; + + const response = await client.agents.create({ + agent_config: agentConfig, + }); + + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + setSelectedAgentId(response.agent_id); + await loadAgentConfig(response.agent_id); + await loadAgentSessions(response.agent_id); + + return response.agent_id; + } catch (error) { + console.error("Error creating agent:", error); + throw error; + } + }, + [client, loadAgentSessions, loadAgentConfig] + ); + + const handleVectorDBCreated = useCallback( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async (_vectorDbId: string) => { + setShowCreateVectorDB(false); + + try { + const vectorDBs = await client.vectorDBs.list(); + const vectorDBsArray = Array.isArray(vectorDBs) ? vectorDBs : []; + + if (vectorDBsArray && Array.isArray(vectorDBsArray)) { + setAvailableVectorDBs(vectorDBsArray); + } + } catch (error) { + console.error("Error refreshing vector DBs:", error); + } + }, + [client] + ); + + const deleteAgent = useCallback( + async (agentId: string) => { + if ( + confirm( + "Are you sure you want to delete this agent? This action cannot be undone and will delete the agent and all its sessions." 
+ ) + ) { + try { + // there's a known error where the delete API returns 500 even on success + try { + await client.agents.delete(agentId); + console.log("Agent deleted successfully"); + } catch (deleteError) { + // log the error but don't re-throw - we know deletion succeeded + console.log( + "Agent delete API returned error (but deletion likely succeeded):", + deleteError + ); + } + + SessionUtils.clearAgentCache(agentId); + + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // if we delete current agent, switch to another + if (selectedAgentId === agentId) { + const remainingAgents = agentList.data?.filter( + (a: { [key: string]: unknown }) => + (a as { agent_id: string }).agent_id !== agentId + ); + if (remainingAgents && remainingAgents.length > 0) { + const newAgent = remainingAgents[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + setSelectedAgentId(newAgent.agent_id); + SessionUtils.saveCurrentAgentId(newAgent.agent_id); + await loadAgentConfig(newAgent.agent_id); + await loadAgentSessions(newAgent.agent_id); + } else { + // no agents left + setSelectedAgentId(""); + setCurrentSession(null); + setSelectedAgentConfig(null); + } + } + } catch (error) { + console.error("Error deleting agent:", error); + + // check if this is known server bug where deletion succeeds but returns 500 + // The error message will typically contain status codes or "Could not find agent" + const errorMessage = + error instanceof Error ? error.message : String(error); + const isKnownServerBug = + errorMessage.includes("500") || + errorMessage.includes("Internal Server Error") || + errorMessage.includes("Could not find agent") || + errorMessage.includes("400"); + + if (isKnownServerBug) { + console.log( + "Agent deletion succeeded despite error, cleaning up UI" + ); + SessionUtils.clearAgentCache(agentId); + try { + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + if (selectedAgentId === agentId) { + const remainingAgents = agentList.data?.filter( + (a: { [key: string]: unknown }) => + (a as { agent_id: string }).agent_id !== agentId + ); + if (remainingAgents && remainingAgents.length > 0) { + const newAgent = remainingAgents[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + setSelectedAgentId(newAgent.agent_id); + SessionUtils.saveCurrentAgentId(newAgent.agent_id); + await loadAgentConfig(newAgent.agent_id); + await loadAgentSessions(newAgent.agent_id); + } else { + // no agents left + setSelectedAgentId(""); + setCurrentSession(null); + setSelectedAgentConfig(null); + } + } + } catch (refreshError) { + console.error("Error refreshing agents list:", refreshError); + } + } else { + // show error that we don't know about to user + console.error("Unexpected error during agent deletion:", error); + if (error instanceof Error) { + alert(`Failed to delete agent: ${error.message}`); + } + } + } + } + }, + [client, selectedAgentId, loadAgentConfig, loadAgentSessions] + ); + + const handleModelChange = useCallback((newModel: string) => 
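// Reviewer note (illustrative sketch): deleteAgent above intentionally treats certain
// server errors as "the agent was deleted anyway" (a known backend quirk where delete
// returns 500/400 or "Could not find agent" despite succeeding). The same classification,
// factored into a small predicate:
function isLikelySuccessfulDelete(error: unknown): boolean {
  const message = error instanceof Error ? error.message : String(error);
  return (
    message.includes("500") ||
    message.includes("Internal Server Error") ||
    message.includes("Could not find agent") ||
    message.includes("400")
  );
}

// usage sketch: if (isLikelySuccessfulDelete(err)) { /* refresh agent list and move on */ }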
{ + setSelectedModel(newModel); + setCurrentSession(prev => + prev + ? { + ...prev, + selectedModel: newModel, + updatedAt: Date.now(), + } + : prev + ); + }, []); + + useEffect(() => { + if (currentSession) { + SessionUtils.saveCurrentSessionId( + currentSession.id, + currentSession.agentId + ); + // cache session data + SessionUtils.saveSessionData(currentSession.agentId, currentSession); + // only update selectedModel if the session has a valid model and it's different from current + if ( + currentSession.selectedModel && + currentSession.selectedModel !== selectedModel + ) { + setSelectedModel(currentSession.selectedModel); + } + } + }, [currentSession, selectedModel]); + useEffect(() => { const fetchModels = async () => { try { setModelsLoading(true); setModelsError(null); const modelList = await client.models.list(); + + // store all models (including embedding models for vector DB creation) + setModels(modelList); + + // set default LLM model for chat const llmModels = modelList.filter(model => model.model_type === "llm"); - setModels(llmModels); if (llmModels.length > 0) { - setSelectedModel(llmModels[0].identifier); + handleModelChange(llmModels[0].identifier); } } catch (err) { console.error("Error fetching models:", err); @@ -49,39 +693,27 @@ export default function ChatPlaygroundPage() { }; fetchModels(); - }, [client]); + }, [client, handleModelChange]); - const extractTextContent = (content: unknown): string => { - if (typeof content === "string") { - return content; - } - if (Array.isArray(content)) { - return content - .filter( - item => - item && - typeof item === "object" && - "type" in item && - item.type === "text" - ) - .map(item => - item && typeof item === "object" && "text" in item - ? String(item.text) - : "" - ) - .join(""); - } + // load agent sessions after both agents and models are ready + useEffect(() => { if ( - content && - typeof content === "object" && - "type" in content && - content.type === "text" && - "text" in content + selectedAgentId && + !agentsLoading && + !modelsLoading && + selectedModel && + !currentSession ) { - return String(content.text) || ""; + loadAgentSessions(selectedAgentId); } - return ""; - }; + }, [ + selectedAgentId, + agentsLoading, + modelsLoading, + selectedModel, + currentSession, + loadAgentSessions, + ]); const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); @@ -91,7 +723,6 @@ export default function ChatPlaygroundPage() { event?.preventDefault?.(); if (!input.trim()) return; - // Add user message to chat const userMessage: Message = { id: Date.now().toString(), role: "user", @@ -99,40 +730,55 @@ export default function ChatPlaygroundPage() { createdAt: new Date(), }; - setMessages(prev => [...prev, userMessage]); + setCurrentSession(prev => { + if (!prev) return prev; + const updatedSession = { + ...prev, + messages: [...prev.messages, userMessage], + updatedAt: Date.now(), + }; + // update cache with new message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); setInput(""); - // Use the helper function with the content await handleSubmitWithContent(userMessage.content); }; const handleSubmitWithContent = async (content: string) => { + if (!currentSession || !selectedAgentId) return; + setIsGenerating(true); setError(null); - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = - typeof msg.content === "string" - ? 
msg.content - : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content }, - ]; + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const userMessage = { + role: "user" as const, + content, + }; + + const turnParams: TurnCreateParams = { + messages: [userMessage], stream: true, - }); + }; + + const response = await client.agents.turn.create( + selectedAgentId, + currentSession.id, + turnParams, + { + signal: abortController.signal, + timeout: 300000, // 5 minutes timeout for RAG queries + } as { signal: AbortSignal; timeout: number } + ); const assistantMessage: Message = { id: (Date.now() + 1).toString(), @@ -141,31 +787,338 @@ export default function ChatPlaygroundPage() { createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); + const processChunk = ( + chunk: unknown + ): { text: string | null; isToolCall: boolean } => { + const chunkObj = chunk as Record; + + // helper to check if content contains function call JSON + const containsToolCall = (content: string): boolean => { + return ( + content.includes('"type": "function"') || + content.includes('"name": "knowledge_search"') || + content.includes('"parameters":') || + !!content.match(/\{"type":\s*"function".*?\}/) + ); + }; + + let isToolCall = false; + let potentialContent = ""; + + if (typeof chunk === "string") { + potentialContent = chunk; + isToolCall = containsToolCall(chunk); + } + + if ( + chunkObj?.delta && + typeof chunkObj.delta === "object" && + chunkObj.delta !== null + ) { + const delta = chunkObj.delta as Record; + if ("tool_calls" in delta) { + isToolCall = true; + } + if (typeof delta.text === "string") { + potentialContent = delta.text; + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + } + + if ( + chunkObj?.event && + typeof chunkObj.event === "object" && + chunkObj.event !== null + ) { + const event = chunkObj.event as Record; + + if ( + event?.payload && + typeof event.payload === "object" && + event.payload !== null + ) { + const payload = event.payload as Record; + if (typeof payload.content === "string") { + potentialContent = payload.content; + if (containsToolCall(payload.content)) { + isToolCall = true; + } + } + + if ( + payload?.delta && + typeof payload.delta === "object" && + payload.delta !== null + ) { + const delta = payload.delta as Record; + if (typeof delta.text === "string") { + potentialContent = delta.text; + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + } + } + + if ( + event?.delta && + typeof event.delta === "object" && + event.delta !== null + ) { + const delta = event.delta as Record; + if (typeof delta.text === "string") { + potentialContent = delta.text; + if (containsToolCall(delta.text)) { + isToolCall = true; + } + } + if (typeof delta.content === "string") { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + potentialContent = delta.content; + if (containsToolCall(delta.content)) { + isToolCall = true; + } + } + } + } + + // if it's a tool call, skip it (don't display in chat) + if 
(isToolCall) { + return { text: null, isToolCall: true }; + } + + let text: string | null = null; + + if ( + chunkObj?.delta && + typeof chunkObj.delta === "object" && + chunkObj.delta !== null + ) { + const delta = chunkObj.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + } + + if ( + !text && + chunkObj?.event && + typeof chunkObj.event === "object" && + chunkObj.event !== null + ) { + const event = chunkObj.event as Record; + + if ( + event?.payload && + typeof event.payload === "object" && + event.payload !== null + ) { + const payload = event.payload as Record; + + if (typeof payload.content === "string") { + text = extractCleanText(payload.content); + } + + if ( + !text && + payload?.turn && + typeof payload.turn === "object" && + payload.turn !== null + ) { + const turn = payload.turn as Record; + if ( + turn?.output_message && + typeof turn.output_message === "object" && + turn.output_message !== null + ) { + const outputMessage = turn.output_message as Record< + string, + unknown + >; + if (typeof outputMessage.content === "string") { + text = extractCleanText(outputMessage.content); + } + } + + if ( + !text && + turn?.steps && + Array.isArray(turn.steps) && + turn.steps.length > 0 + ) { + for (const step of turn.steps) { + if (step && typeof step === "object" && step !== null) { + const stepObj = step as Record; + if ( + stepObj?.model_response && + typeof stepObj.model_response === "object" && + stepObj.model_response !== null + ) { + const modelResponse = stepObj.model_response as Record< + string, + unknown + >; + if (typeof modelResponse.content === "string") { + text = extractCleanText(modelResponse.content); + break; + } + } + } + } + } + } + + if ( + !text && + payload?.delta && + typeof payload.delta === "object" && + payload.delta !== null + ) { + const delta = payload.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + } + } + + if ( + !text && + event?.delta && + typeof event.delta === "object" && + event.delta !== null + ) { + const delta = event.delta as Record; + if (typeof delta.text === "string") { + text = extractCleanText(delta.text); + } + if (!text && typeof delta.content === "string") { + text = extractCleanText(delta.content); + } + } + } + + if ( + !text && + chunkObj?.choices && + Array.isArray(chunkObj.choices) && + chunkObj.choices.length > 0 + ) { + const choice = chunkObj.choices[0] as Record; + if ( + choice?.delta && + typeof choice.delta === "object" && + choice.delta !== null + ) { + const delta = choice.delta as Record; + if (typeof delta.content === "string") { + text = extractCleanText(delta.content); + } + } + } + + if (!text && typeof chunk === "string") { + text = extractCleanText(chunk); + } + + return { text, isToolCall: false }; + }; + setCurrentSession(prev => { + if (!prev) return null; + const updatedSession = { + ...prev, + messages: [...prev.messages, assistantMessage], + updatedAt: Date.now(), + }; + // update cache with assistant message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); + let fullContent = ""; + for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + const { text: deltaText } = processChunk(chunk); + + // logging for debugging function calls + // if (deltaText && deltaText.includes("knowledge_search")) { + // console.log("🔍 Function call detected 
in text output:", deltaText); + // console.log("🔍 Original chunk:", JSON.stringify(chunk, null, 2)); + // } + + if (chunk && typeof chunk === "object" && "event" in chunk) { + const event = ( + chunk as { + event: { + payload?: { + event_type?: string; + turn?: { output_message?: { content?: string } }; + }; + }; + } + ).event; + if (event?.payload?.event_type === "turn_complete") { + const content = event?.payload?.turn?.output_message?.content; + if (content && content.includes("knowledge_search")) { + console.log("🔍 Function call found in turn_complete:", content); + } + } + } + + if (deltaText) { + fullContent += deltaText; flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; + setCurrentSession(prev => { + if (!prev) return null; + const newMessages = [...prev.messages]; + const last = newMessages[newMessages.length - 1]; + if (last.role === "assistant") { + last.content = fullContent; } - return newMessages; + const updatedSession = { + ...prev, + messages: newMessages, + updatedAt: Date.now(), + }; + // update cache with streaming content + if (fullContent.length % 100 === 0) { + // Only cache every 100 characters + SessionUtils.saveSessionData(prev.agentId, updatedSession); + } + return updatedSession; }); }); } } } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + console.log("Request aborted"); + return; + } + console.error("Error sending message:", err); setError("Failed to send message. Please try again."); - setMessages(prev => prev.slice(0, -1)); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: prev.messages.slice(0, -1), + updatedAt: Date.now(), + } + : prev + ); } finally { setIsGenerating(false); + abortControllerRef.current = null; + // cache final session state after streaming completes + setCurrentSession(prev => { + if (prev) { + SessionUtils.saveSessionData(prev.agentId, prev); + } + return prev; + }); } }; const suggestions = [ @@ -181,69 +1134,760 @@ export default function ChatPlaygroundPage() { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: [...prev.messages, newMessage], + updatedAt: Date.now(), + } + : prev + ); handleSubmitWithContent(newMessage.content); }; const clearChat = () => { - setMessages([]); + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + setIsGenerating(false); + } + + setCurrentSession(prev => + prev ? 
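// Reviewer note (illustrative sketch): during streaming, the session is only re-cached when
// fullContent.length is a multiple of 100, with a final save in the finally block. A variant
// of that throttling idea as a standalone helper (threshold value is arbitrary):
function makeThrottledSaver(save: (content: string) => void, minDelta = 100) {
  let lastSavedLength = 0;
  return (content: string) => {
    if (content.length - lastSavedLength >= minDelta) {
      lastSavedLength = content.length;
      save(content);
    }
  };
}

// usage sketch:
// const maybeSave = makeThrottledSaver(c => SessionUtils.saveSessionData(agentId, session), 100);
// inside the streaming loop: maybeSave(fullContent);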
{ ...prev, messages: [], updatedAt: Date.now() } : prev + ); setError(null); }; + const handleRAGFileUpload = async (file: File) => { + if (!selectedAgentConfig?.toolgroups || !selectedAgentId) { + setError("No agent selected or agent has no RAG tools configured"); + return; + } + + // find RAG toolgroups that have vector_db_ids configured + const ragToolgroups = selectedAgentConfig.toolgroups.filter(toolgroup => { + if (typeof toolgroup === "object" && toolgroup.name?.includes("rag")) { + return toolgroup.args && "vector_db_ids" in toolgroup.args; + } + return false; + }); + + if (ragToolgroups.length === 0) { + setError("Current agent has no vector databases configured for RAG"); + return; + } + + try { + setError(null); + console.log("Uploading file using RAG tool..."); + + setUploadNotification({ + show: true, + message: `📄 Uploading and indexing "${file.name}"...`, + type: "loading", + }); + + const vectorDbIds = ragToolgroups.flatMap(toolgroup => { + if ( + typeof toolgroup === "object" && + toolgroup.args && + "vector_db_ids" in toolgroup.args + ) { + return toolgroup.args.vector_db_ids as string[]; + } + return []; + }); + + // determine mime type from file extension - this should be in the Llama Stack Client IMO + const getContentType = (filename: string): string => { + const ext = filename.toLowerCase().split(".").pop(); + switch (ext) { + case "pdf": + return "application/pdf"; + case "txt": + return "text/plain"; + case "md": + return "text/markdown"; + case "html": + return "text/html"; + case "csv": + return "text/csv"; + case "json": + return "application/json"; + case "docx": + return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"; + case "doc": + return "application/msword"; + default: + return "application/octet-stream"; + } + }; + + const mimeType = getContentType(file.name); + let fileContent: string; + + // handle text files vs binary files differently + const isTextFile = + mimeType.startsWith("text/") || + mimeType === "application/json" || + mimeType === "text/markdown" || + mimeType === "text/html" || + mimeType === "text/csv"; + + if (isTextFile) { + fileContent = await file.text(); + } else { + // for PDFs and other binary files, create a data URL + // use FileReader for efficient base64 conversion + fileContent = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as string); + reader.onerror = () => reject(reader.error); + reader.readAsDataURL(file); + }); + } + + for (const vectorDbId of vectorDbIds) { + await client.toolRuntime.ragTool.insert({ + documents: [ + { + content: fileContent, + document_id: `${file.name}-${Date.now()}`, + metadata: { + filename: file.name, + file_size: file.size, + uploaded_at: new Date().toISOString(), + agent_id: selectedAgentId, + }, + mime_type: mimeType, + }, + ], + vector_db_id: vectorDbId, + // TODO: parameterize this somewhere, probably in settings + chunk_size_in_tokens: 512, + }); + } + + console.log("✅ File successfully uploaded using RAG tool"); + + setUploadNotification({ + show: true, + message: `📄 File "${file.name}" uploaded and indexed successfully!`, + type: "success", + }); + + setTimeout(() => { + setUploadNotification(prev => ({ ...prev, show: false })); + }, 4000); + } catch (err) { + console.error("Error uploading file using RAG tool:", err); + const errorMessage = + err instanceof Error + ? 
`Failed to upload file: ${err.message}` + : "Failed to upload file using RAG tool"; + + setUploadNotification({ + show: true, + message: errorMessage, + type: "error", + }); + + setTimeout(() => { + setUploadNotification(prev => ({ ...prev, show: false })); + }, 6000); + } + }; + return ( -
-
-

Chat Playground (Completions)

-
- - + className="ml-2 text-gray-400 hover:text-gray-600" + > + ✕ + + )} +
+
+ )} + + {/* Header */} +
+
+

Agent Session:

+
+ {!agentsLoading && agents.length > 0 && ( +
+ + + {selectedAgentId && ( + + )} +
+ )} + + {!agentsLoading && agents.length > 0 && ( + + )} +
+
+
+ {/* Main Two-Column Layout */} +
+ {/* Left Column - Configuration Panel */} +
+

+ Settings +

+ + {/* Model Configuration */} +
+

+ Model Configuration +

+
+
+ + + {modelsError && ( +

{modelsError}

+ )} +
+ +
+ +
+ {(selectedAgentId && + agents.find(a => a.agent_id === selectedAgentId) + ?.agent_config?.instructions) || + "No agent selected"} +
+

+ Instructions are set when creating an agent and cannot be + changed. +

+
+
+
+ + {/* Agent Tools */} +
+

+ Agent Tools +

+
+
+ +
+ {selectedAgentConfig?.toolgroups && + selectedAgentConfig.toolgroups.length > 0 ? ( + selectedAgentConfig.toolgroups.map( + ( + toolgroup: + | string + | { name: string; args: Record }, + index: number + ) => { + const toolName = + typeof toolgroup === "string" + ? toolgroup + : toolgroup.name; + const toolArgs = + typeof toolgroup === "object" ? toolgroup.args : null; + + const isRAGTool = toolName.includes("rag"); + const displayName = isRAGTool ? "RAG Search" : toolName; + const displayIcon = isRAGTool + ? "🔍" + : toolName.includes("search") + ? "🌐" + : "🔧"; + + return ( +
+
+
+ {displayIcon} + + {displayName} + +
+
+ {isRAGTool && toolArgs && toolArgs.vector_db_ids ? ( +
+ + Vector Databases: + +
+ {Array.isArray(toolArgs.vector_db_ids) ? ( + toolArgs.vector_db_ids.map( + (dbId: string, idx: number) => ( + + {dbId} + + ) + ) + ) : ( + + {String(toolArgs.vector_db_ids)} + + )} +
+
+ ) : null} + {!isRAGTool && + toolArgs && + Object.keys(toolArgs).length > 0 && ( +
+ + Configuration: + {" "} + {Object.keys(toolArgs).length} parameter + {Object.keys(toolArgs).length > 1 ? "s" : ""} +
+ )} +
+ ); + } + ) + ) : ( +
+

+ No tools configured +

+

+ This agent only has text generation capabilities +

+
+ )} +
+

+ Tools are configured when creating an agent and provide + additional capabilities like web search, math calculations, or + RAG document retrieval. +

+
+
+
+
+ + {/* Right Column - Chat Interface */} +
+ {error && ( +
+

{error}

+
+ )} + + {!agentsLoading && agents.length === 0 ? ( +
+
+
🦙
+

+ Create an Agent with Llama Stack +

+

+ To get started, create your first agent. Each agent is + configured with specific instructions, models, and tools to + help you with different tasks. +

+ +
+
+ ) : ( + + setCurrentSession(prev => + prev ? { ...prev, messages, updatedAt: Date.now() } : prev + ) + } + onRAGFileUpload={handleRAGFileUpload} + /> + )}
- {modelsError && ( -
-

{modelsError}

+ {/* Create Agent Modal */} + {showCreateAgent && ( +
+ +

Create New Agent

+ +
+
+ + setNewAgentName(e.target.value)} + placeholder="My Custom Agent" + /> +
+ +
+ + +
+ +
+ +