diff --git a/.github/workflows/gha_workflow_llama_stack_tests.yml b/.github/workflows/gha_workflow_llama_stack_tests.yml new file mode 100644 index 000000000..89e5edf71 --- /dev/null +++ b/.github/workflows/gha_workflow_llama_stack_tests.yml @@ -0,0 +1,355 @@ +name: "Run Llama-stack Tests" + +on: + #### Temporarily disable PR runs until tests run as intended within mainline. + #TODO Add this back. + #pull_request_target: + # types: ["opened"] + # branches: + # - 'main' + # paths: + # - 'llama_stack/**/*.py' + # - 'tests/**/*.py' + + workflow_dispatch: + inputs: + runner: + description: 'GHA Runner Scale Set label to run workflow on.' + required: true + default: "llama-stack-gha-runner-gpu" + + checkout_reference: + description: "The branch, tag, or SHA to checkout" + required: true + default: "main" + + debug: + description: 'Run debugging steps?' + required: false + default: "true" + + sleep_time: + description: '[DEBUG] sleep time for debugging' + required: true + default: "0" + + provider_id: + description: 'ID of your provider' + required: true + default: "meta_reference" + + model_id: + description: 'Shorthand name for target model ID (llama_3b or llama_8b)' + required: true + default: "llama_3b" + + model_override_3b: + description: 'Specify shorthand model for ' + required: false + default: "Llama3.2-3B-Instruct" + + model_override_8b: + description: 'Specify shorthand model for ' + required: false + default: "Llama3.1-8B-Instruct" + +env: + # ID used for each test's provider config + PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}" + + # Path to model checkpoints within EFS volume + MODEL_CHECKPOINT_DIR: "/data/llama" + + # Path to directory to run tests from + TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests" + + # Keep track of a list of model IDs that are valid to use within pytest fixture marks + AVAILABLE_MODEL_IDs: "llama_3b llama_8b" + + # Shorthand name for model ID, used in pytest fixture marks + MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}" + + # Override the `llama_3b` / `llama_8b' models, else use the default. + LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}" + LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}" + + # Defines which directories in TESTS_PATH to exclude from the test loop + EXCLUDED_DIRS: "__pycache__" + + # Defines the output xml reports generated after a test is run + REPORTS_GEN: "" + +jobs: + execute_workflow: + name: Execute workload on Self-Hosted GPU k8s runner + permissions: + pull-requests: write + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }} + if: always() + steps: + + ############################## + #### INITIAL DEBUG CHECKS #### + ############################## + - name: "[DEBUG] Check content of the EFS mount" + id: debug_efs_volume + continue-on-error: true + if: inputs.debug == 'true' + run: | + echo "========= Content of the EFS mount =============" + ls -la ${{ env.MODEL_CHECKPOINT_DIR }} + + - name: "[DEBUG] Get runner container OS information" + id: debug_os_info + if: ${{ inputs.debug == 'true' }} + run: | + cat /etc/os-release + + - name: "[DEBUG] Print environment variables" + id: debug_env_vars + if: ${{ inputs.debug == 'true' }} + run: | + echo "PROVIDER_ID = ${PROVIDER_ID}" + echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}" + echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}" + echo "MODEL_ID = ${MODEL_ID}" + echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}" + echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}" + echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}" + echo "REPORTS_GEN = ${REPORTS_GEN}" + + ############################ + #### MODEL INPUT CHECKS #### + ############################ + + - name: "Check if env.model_id is valid" + id: check_model_id + run: | + if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then + echo "Model ID '${MODEL_ID}' is valid." + else + echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow." + exit 1 + fi + + ####################### + #### CODE CHECKOUT #### + ####################### + - name: "Checkout 'meta-llama/llama-stack' repository" + id: checkout_repo + uses: actions/checkout@v4 + with: + ref: ${{ inputs.branch }} + + - name: "[DEBUG] Content of the repository after checkout" + id: debug_content_after_checkout + if: ${{ inputs.debug == 'true' }} + run: | + ls -la ${GITHUB_WORKSPACE} + + ########################################################## + #### OPTIONAL SLEEP DEBUG #### + # # + # Use to "exec" into the test k8s POD and run tests # + # manually to identify what dependencies are being used. # + # # + ########################################################## + - name: "[DEBUG] sleep" + id: debug_sleep + if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }} + run: | + sleep ${{ inputs.sleep_time }} + + ############################ + #### UPDATE SYSTEM PATH #### + ############################ + - name: "Update path: execute" + id: path_update_exec + run: | + # .local/bin is needed for certain libraries installed below to be recognized + # when calling their executable to install sub-dependencies + mkdir -p ${HOME}/.local/bin + echo "${HOME}/.local/bin" >> "$GITHUB_PATH" + + ##################################### + #### UPDATE CHECKPOINT DIRECTORY #### + ##################################### + - name: "Update checkpoint directory" + id: checkpoint_update + run: | + echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE" + if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then + echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV" + elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then + echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV" + else + echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow." + exit 1 + fi + + - name: "[DEBUG] Checkpoint update check" + id: debug_checkpoint_update + if: ${{ inputs.debug == 'true' }} + run: | + echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}" + + ################################## + #### DEPENDENCY INSTALLATIONS #### + ################################## + - name: "Installing 'apt' required packages" + id: install_apt + run: | + echo "[STEP] Installing 'apt' required packages" + sudo apt update -y + sudo apt install -y python3 python3-pip npm wget + + - name: "Installing packages with 'curl'" + id: install_curl + run: | + curl -fsSL https://ollama.com/install.sh | sh + + - name: "Installing packages with 'wget'" + id: install_wget + run: | + wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh + chmod +x Miniconda3-latest-Linux-x86_64.sh + ./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0 + # Add miniconda3 bin to system path + echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH" + + - name: "Installing packages with 'npm'" + id: install_npm_generic + run: | + sudo npm install -g junit-merge + + - name: "Installing pip dependencies" + id: install_pip_generic + run: | + echo "[STEP] Installing 'llama-stack' models" + pip install -U pip setuptools + pip install -r requirements.txt + pip install -e . + pip install -U \ + torch torchvision \ + pytest pytest_asyncio \ + fairscale lm-format-enforcer \ + zmq chardet pypdf \ + pandas sentence_transformers together \ + aiosqlite + - name: "Installing packages with conda" + id: install_conda_generic + run: | + conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0 + + ############################################################# + #### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH #### + ############################################################# + - name: "Run Tests: Loop" + id: run_tests_loop + working-directory: "${{ github.workspace }}" + run: | + pattern="" + for dir in llama_stack/providers/tests/*; do + if [ -d "$dir" ]; then + dir_name=$(basename "$dir") + if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then + for file in "$dir"/test_*.py; do + test_name=$(basename "$file") + new_file="result-${dir_name}-${test_name}.xml" + if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \ + --junitxml="${{ github.workspace }}/${new_file}"; then + echo "Ran test: ${test_name}" + else + echo "Did NOT run test: ${test_name}" + fi + pattern+="${new_file} " + done + fi + fi + done + echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV" + + - name: "Test Summary: Merge" + id: test_summary_merge + working-directory: "${{ github.workspace }}" + run: | + echo "Merging the following test result files: ${REPORTS_GEN}" + # Defaults to merging them into 'merged-test-results.xml' + junit-merge ${{ env.REPORTS_GEN }} + + ############################################ + #### AUTOMATIC TESTING ON PULL REQUESTS #### + ############################################ + + #### Run tests #### + + - name: "PR - Run Tests" + id: pr_run_tests + working-directory: "${{ github.workspace }}" + if: github.event_name == 'pull_request_target' + run: | + echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}" + # (Optional) Add more tests here. + + # Merge test results with 'merged-test-results.xml' from above. + # junit-merge merged-test-results.xml + + #### Create test summary #### + + - name: "PR - Test Summary" + id: pr_test_summary_create + if: github.event_name == 'pull_request_target' + uses: test-summary/action@v2 + with: + paths: "${{ github.workspace }}/merged-test-results.xml" + output: test-summary.md + + - name: "PR - Upload Test Summary" + id: pr_test_summary_upload + if: github.event_name == 'pull_request_target' + uses: actions/upload-artifact@v3 + with: + name: test-summary + path: test-summary.md + + #### Update PR request #### + + - name: "PR - Update comment" + id: pr_update_comment + if: github.event_name == 'pull_request_target' + uses: thollander/actions-comment-pull-request@v2 + with: + filePath: test-summary.md + + ######################## + #### MANUAL TESTING #### + ######################## + + #### Run tests #### + + - name: "Manual - Run Tests: Prep" + id: manual_run_tests + working-directory: "${{ github.workspace }}" + if: github.event_name == 'workflow_dispatch' + run: | + echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}" + + #TODO Use this when collection errors are resolved + # pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml" + + # (Optional) Add more tests here. + + # Merge test results with 'merged-test-results.xml' from above. + # junit-merge merged-test-results.xml + + #### Create test summary #### + + - name: "Manual - Test Summary" + id: manual_test_summary + if: always() && github.event_name == 'workflow_dispatch' + uses: test-summary/action@v2 + with: + paths: "${{ github.workspace }}/merged-test-results.xml" diff --git a/README.md b/README.md index 8e57292c3..0dfb1306d 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Cerebras | Single Node | | :heavy_check_mark: | | | | | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | @@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well |:----------------: |:------------------------------------------: |:-----------------------: | | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | +| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | diff --git a/distributions/cerebras/build.yaml b/distributions/cerebras/build.yaml new file mode 120000 index 000000000..bccbbcf60 --- /dev/null +++ b/distributions/cerebras/build.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/build.yaml \ No newline at end of file diff --git a/distributions/cerebras/compose.yaml b/distributions/cerebras/compose.yaml new file mode 100644 index 000000000..f2e9a6f42 --- /dev/null +++ b/distributions/cerebras/compose.yaml @@ -0,0 +1,16 @@ +services: + llamastack: + image: llamastack/distribution-cerebras + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-cerebras.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/cerebras/run.yaml b/distributions/cerebras/run.yaml new file mode 120000 index 000000000..9f9d20b4b --- /dev/null +++ b/distributions/cerebras/run.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/run.yaml \ No newline at end of file diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 36426e862..80468cc73 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -1,4 +1,152 @@ { + "tgi": [ + "aiohttp", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "huggingface_hub", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "remote-vllm": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "vllm-gpu": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "vllm", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-quantized-gpu": [ + "accelerate", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fbgemm-gpu", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchao==0.5.0", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-gpu": [ + "accelerate", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], "hf-serverless": [ "aiohttp", "aiosqlite", @@ -54,88 +202,7 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "vllm-gpu": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "vllm", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "remote-vllm": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "fireworks": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "tgi": [ + "ollama": [ "aiohttp", "aiosqlite", "blobfile", @@ -145,10 +212,10 @@ "fastapi", "fire", "httpx", - "huggingface_hub", "matplotlib", "nltk", "numpy", + "ollama", "pandas", "pillow", "psycopg2-binary", @@ -190,100 +257,6 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "meta-reference-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "meta-reference-quantized-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fbgemm-gpu", - "fire", - "httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchao==0.5.0", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "ollama": [ - "aiohttp", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "ollama", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], "hf-endpoint": [ "aiohttp", "aiosqlite", @@ -311,5 +284,58 @@ "uvicorn", "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "fireworks": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "fireworks-ai", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "cerebras": [ + "aiosqlite", + "blobfile", + "cerebras_cloud_sdk", + "chardet", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 090253804..4f220ea1e 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2291,6 +2291,39 @@ "required": true } } + }, + "/alpha/datasets/unregister": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "Datasets" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnregisterDatasetRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -7917,6 +7950,18 @@ "required": [ "model_id" ] + }, + "UnregisterDatasetRequest": { + "type": "object", + "properties": { + "dataset_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "dataset_id" + ] } }, "responses": {} @@ -8529,6 +8574,10 @@ "name": "UnregisterModelRequest", "description": "" }, + { + "name": "UnregisterDatasetRequest", + "description": "" + }, { "name": "UnstructuredLogEvent", "description": "" @@ -8718,6 +8767,7 @@ "URL", "UnregisterMemoryBankRequest", "UnregisterModelRequest", + "UnregisterDatasetRequest", "UnstructuredLogEvent", "UserMessage", "VectorMemoryBank", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 8ffd9fdef..6564ddf3f 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -3253,6 +3253,14 @@ components: required: - model_id type: object + UnregisterDatasetRequest: + additionalProperties: false + properties: + dataset_id: + type: string + required: + - dataset_id + type: object UnstructuredLogEvent: additionalProperties: false properties: @@ -3789,6 +3797,27 @@ paths: description: OK tags: - Datasets + /alpha/datasets/unregister: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UnregisterDatasetRequest' + required: true + responses: + '200': + description: OK + tags: + - Datasets /alpha/eval-tasks/get: get: parameters: @@ -5242,6 +5271,9 @@ tags: - description: name: UnregisterModelRequest +- description: + name: UnregisterDatasetRequest - description: name: UnstructuredLogEvent @@ -5418,6 +5450,7 @@ x-tagGroups: - URL - UnregisterMemoryBankRequest - UnregisterModelRequest + - UnregisterDatasetRequest - UnstructuredLogEvent - UserMessage - VectorMemoryBank diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index deb475b16..4dfafa5b8 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -66,129 +66,265 @@ llama stack build --list-templates ``` ``` -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| Template Name | Providers | Description | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM | -| | "inference": "remote::hf::serverless", | inference. | -| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| together | { | Use Together.ai for running LLM inference | -| | "inference": "remote::together", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| fireworks | { | Use Fireworks.ai for running LLM inference | -| | "inference": "remote::fireworks", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| databricks | { | Use Databricks for running LLM inference | -| | "inference": "remote::databricks", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| sambanova | { | Use SambaNova.ai for running LLM inference | -| | "inference": "remote::sambanova", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| vllm | { | Like local, but use vLLM for running LLM inference | -| | "inference": "vllm", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| tgi | { | Use TGI for running LLM inference | -| | "inference": "remote::tgi", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| bedrock | { | Use Amazon Bedrock APIs. | -| | "inference": "remote::bedrock", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference-quantized", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| ollama | { | Use ollama for running LLM inference | -| | "inference": "remote::ollama", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. | -| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| Template Name | Providers | Description | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| tgi | { | Use (an external) TGI server for running LLM inference | +| | "inference": [ | | +| | "remote::tgi" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| remote-vllm | { | Use (an external) vLLM server for running LLM inference | +| | "inference": [ | | +| | "remote::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference | +| | "inference": [ | | +| | "inline::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference-quantized" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-gpu | { | Use Meta Reference for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::serverless" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| together | { | Use Together.AI for running LLM inference | +| | "inference": [ | | +| | "remote::together" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| ollama | { | Use (an external) Ollama server for running LLM inference | +| | "inference": [ | | +| | "remote::ollama" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| bedrock | { | Use AWS Bedrock for running LLM inference and safety | +| | "inference": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::endpoint" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| fireworks | { | Use Fireworks.AI for running LLM inference | +| | "inference": [ | | +| | "remote::fireworks" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| cerebras | { | Use Cerebras for running LLM inference | +| | "inference": [ | | +| | "remote::cerebras" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "memory": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| cerebras | { | Use SambaNova.ai for running LLM inference | +| | "inference": [ | | +| | "remote::sambanova" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "memory": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ ``` You may then pick a template to build your distribution with providers fitted to your liking. diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index 815660fd4..7e15062df 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -21,7 +21,7 @@ print(response) ```python response = await client.inference.chat_completion( messages=[UserMessage(content="What is the capital of France?", role="user")], - model="Llama3.1-8B-Instruct", + model_id="Llama3.1-8B-Instruct", stream=False, ) print("\nChat completion response:") diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md new file mode 100644 index 000000000..08b35809a --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -0,0 +1,61 @@ +# Cerebras Distribution + +The `llamastack/distribution-cerebras` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| inference | `remote::cerebras` | +| memory | `inline::meta-reference` | +| safety | `inline::llama-guard` | +| telemetry | `inline::meta-reference` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)` +- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)` + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-cerebras \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index 0eb245483..9f81d9329 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -118,9 +118,9 @@ llama stack run ./run-with-safety.yaml \ ### (Optional) Update Model Serving Configuration -> [!NOTE] -> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models. - +```{note} +Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) variable for supported Ollama models. +``` To serve a new model with `ollama` ```bash diff --git a/docs/source/index.md b/docs/source/index.md index 1dc65d62b..dccc809c7 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere | **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | Y | Y | Y | Y | Y | +| Cerebras | Single Node | | Y | | | | | Fireworks | Hosted | Y | Y | Y | | | | AWS Bedrock | Hosted | | Y | | Y | | | Together | Hosted | Y | Y | | Y | | diff --git a/docs/source/references/llama_stack_client_cli_reference.md b/docs/source/references/llama_stack_client_cli_reference.md index d3835e488..b35aa189d 100644 --- a/docs/source/references/llama_stack_client_cli_reference.md +++ b/docs/source/references/llama_stack_client_cli_reference.md @@ -27,8 +27,6 @@ $ llama-stack-client configure Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5000 ``` -## Provider Commands - ### `llama-stack-client providers list` ```bash $ llama-stack-client providers list @@ -119,8 +117,25 @@ $ llama-stack-client memory_banks list +--------------+----------------+--------+-------------------+------------------------+--------------------------+ ``` -## Shield Management +### `llama-stack-client memory_banks register` +```bash +$ llama-stack-client memory_banks register --type [--provider-id ] [--provider-memory-bank-id ] [--chunk-size ] [--embedding-model ] [--overlap-size ] +``` +Options: +- `--type`: Required. Type of memory bank. Choices: "vector", "keyvalue", "keyword", "graph" +- `--provider-id`: Optional. Provider ID for the memory bank +- `--provider-memory-bank-id`: Optional. Provider's memory bank ID +- `--chunk-size`: Optional. Chunk size in tokens (for vector type). Default: 512 +- `--embedding-model`: Optional. Embedding model (for vector type). Default: "all-MiniLM-L6-v2" +- `--overlap-size`: Optional. Overlap size in tokens (for vector type). Default: 64 + +### `llama-stack-client memory_banks unregister` +```bash +$ llama-stack-client memory_banks unregister +``` + +## Shield Management ### `llama-stack-client shields list` ```bash $ llama-stack-client shields list @@ -134,16 +149,51 @@ $ llama-stack-client shields list +--------------+----------+----------------+-------------+ ``` -## Evaluation Tasks +### `llama-stack-client shields register` +```bash +$ llama-stack-client shields register --shield-id [--provider-id ] [--provider-shield-id ] [--params ] +``` + +Options: +- `--shield-id`: Required. ID of the shield +- `--provider-id`: Optional. Provider ID for the shield +- `--provider-shield-id`: Optional. Provider's shield ID +- `--params`: Optional. JSON configuration parameters for the shield + +## Eval Task Management ### `llama-stack-client eval_tasks list` ```bash -$ llama-stack-client eval run_benchmark --num-examples 10 --output-dir ./ --eval-task-config ~/eval_task_config.json +$ llama-stack-client eval_tasks list ``` -where `eval_task_config.json` is the path to the eval task config file in JSON format. An example eval_task_config +### `llama-stack-client eval_tasks register` +```bash +$ llama-stack-client eval_tasks register --eval-task-id --dataset-id --scoring-functions [ ...] [--provider-id ] [--provider-eval-task-id ] [--metadata ] ``` -$ cat ~/eval_task_config.json + +Options: +- `--eval-task-id`: Required. ID of the eval task +- `--dataset-id`: Required. ID of the dataset to evaluate +- `--scoring-functions`: Required. One or more scoring functions to use for evaluation +- `--provider-id`: Optional. Provider ID for the eval task +- `--provider-eval-task-id`: Optional. Provider's eval task ID +- `--metadata`: Optional. Metadata for the eval task in JSON format + +## Eval execution +### `llama-stack-client eval run-benchmark` +```bash +$ llama-stack-client eval run-benchmark [ ...] --eval-task-config --output-dir [--num-examples ] [--visualize] +``` + +Options: +- `--eval-task-config`: Required. Path to the eval task config file in JSON format +- `--output-dir`: Required. Path to the directory where evaluation results will be saved +- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging) +- `--visualize`: Optional flag. If set, visualizes evaluation results after completion + +Example eval_task_config.json: +```json { "type": "benchmark", "eval_candidate": { @@ -160,3 +210,14 @@ $ cat ~/eval_task_config.json } } ``` + +### `llama-stack-client eval run-scoring` +```bash +$ llama-stack-client eval run-scoring --eval-task-config --output-dir [--num-examples ] [--visualize] +``` + +Options: +- `--eval-task-config`: Required. Path to the eval task config file in JSON format +- `--output-dir`: Required. Path to the directory where scoring results will be saved +- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging) +- `--visualize`: Optional flag. If set, visualizes scoring results after completion diff --git a/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb b/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb index e9bff5f33..8e3949e94 100644 --- a/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb +++ b/docs/zero_to_hero_guide/Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb @@ -71,7 +71,7 @@ } ], "source": [ - "!pip install llama-stack-client" + "!pip install llama-stack-client==0.0.50" ] }, { diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 25de35497..d2243c96f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.common.deployment_types import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel): @runtime_checkable +@trace_protocol class Agents(Protocol): @webmethod(route="/agents/create") async def create_agent( diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index c5052877a..22acc3211 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -37,3 +37,8 @@ class DatasetIO(Protocol): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... + + @webmethod(route="/datasetio/append-rows", method="POST") + async def append_rows( + self, dataset_id: str, rows: List[Dict[str, Any]] + ) -> None: ... diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py index 9e5891e74..c379a49fb 100644 --- a/llama_stack/apis/datasets/client.py +++ b/llama_stack/apis/datasets/client.py @@ -78,6 +78,21 @@ class DatasetsClient(Datasets): return [DatasetDefWithProvider(**x) for x in response.json()] + async def unregister_dataset( + self, + dataset_id: str, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.delete( + f"{self.base_url}/datasets/unregister", + params={ + "dataset_id": dataset_id, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + async def run_main(host: str, port: int): client = DatasetsClient(f"http://{host}:{port}") diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 2ab958782..e1ac4af21 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -64,3 +64,9 @@ class Datasets(Protocol): @webmethod(route="/datasets/list", method="GET") async def list_datasets(self) -> List[Dataset]: ... + + @webmethod(route="/datasets/unregister", method="POST") + async def unregister_dataset( + self, + dataset_id: str, + ) -> None: ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 5aadd97c7..85b29a147 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -220,6 +222,7 @@ class ModelStore(Protocol): @runtime_checkable +@trace_protocol class Inference(Protocol): model_store: ModelStore diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 48b6e2241..b75df8a1a 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -43,6 +44,7 @@ class MemoryBankStore(Protocol): @runtime_checkable +@trace_protocol class Memory(Protocol): memory_bank_store: MemoryBankStore diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 1b16af330..0b8b2563f 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol @json_schema_type @@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel): @runtime_checkable +@trace_protocol class MemoryBanks(Protocol): @webmethod(route="/memory-banks/list", method="GET") async def list_memory_banks(self) -> List[MemoryBank]: ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index cbd6265e2..2c0f1ee21 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonModelFields(BaseModel): @@ -43,6 +44,7 @@ class ModelInput(CommonModelFields): @runtime_checkable +@trace_protocol class Models(Protocol): @webmethod(route="/models/list", method="GET") async def list_models(self) -> List[Model]: ... diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 724f8dc96..41058f107 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel +from llama_stack.distribution.tracing import trace_protocol + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 @@ -43,6 +45,7 @@ class ShieldStore(Protocol): @runtime_checkable +@trace_protocol class Safety(Protocol): shield_store: ShieldStore diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5ee444f68..b28605727 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.distribution.tracing import trace_protocol class CommonShieldFields(BaseModel): @@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields): @runtime_checkable +@trace_protocol class Shields(Protocol): @webmethod(route="/shields/list", method="GET") async def list_shields(self) -> List[Shield]: ... diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 31f64733b..2ff783c46 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -6,12 +6,24 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +# Add this constant near the top of the file, after the imports +DEFAULT_TTL_DAYS = 7 + @json_schema_type class SpanStatus(Enum): @@ -29,6 +41,11 @@ class Span(BaseModel): end_time: Optional[datetime] = None attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) + def set_attribute(self, key: str, value: Any): + if self.attributes is None: + self.attributes = {} + self.attributes[key] = value + @json_schema_type class Trace(BaseModel): @@ -123,10 +140,49 @@ Event = Annotated[ ] +@json_schema_type +class EvalTrace(BaseModel): + session_id: str + step: str + input: str + output: str + expected_output: str + + +@json_schema_type +class SpanWithChildren(Span): + children: List["SpanWithChildren"] = Field(default_factory=list) + status: Optional[SpanStatus] = None + + +@json_schema_type +class QueryCondition(BaseModel): + key: str + op: Literal["eq", "ne", "gt", "lt"] + value: Any + + @runtime_checkable class Telemetry(Protocol): - @webmethod(route="/telemetry/log-event") - async def log_event(self, event: Event) -> None: ... - @webmethod(route="/telemetry/get-trace", method="GET") - async def get_trace(self, trace_id: str) -> Trace: ... + @webmethod(route="/telemetry/log-event") + async def log_event( + self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400 + ) -> None: ... + + @webmethod(route="/telemetry/query-traces", method="POST") + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + @webmethod(route="/telemetry/get-span-tree", method="POST") + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5a62b6d64..5b75a525b 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO): filter_condition=filter_condition, ) + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) + class ScoringRouter(Scoring): def __init__( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4df693b26..2fb5a5e1c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_memory_bank(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await p.unregister_dataset(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): ) await self.register_object(dataset) + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) + class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFn]: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8116e2b39..4ae1854df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -43,9 +43,9 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) -from llama_stack.providers.inline.meta_reference.telemetry.console import ( - ConsoleConfig, - ConsoleTelemetryImpl, +from llama_stack.providers.inline.telemetry.meta_reference import ( + TelemetryAdapter, + TelemetryConfig, ) from .endpoints import get_all_api_endpoints @@ -290,7 +290,7 @@ def main(): if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: - setup_logger(ConsoleTelemetryImpl(ConsoleConfig())) + setup_logger(TelemetryAdapter(TelemetryConfig())) all_endpoints = get_all_api_endpoints() diff --git a/llama_stack/distribution/tracing.py b/llama_stack/distribution/tracing.py new file mode 100644 index 000000000..ea663ec89 --- /dev/null +++ b/llama_stack/distribution/tracing.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import inspect +import json +from functools import wraps +from typing import Any, AsyncGenerator, Callable, Type, TypeVar + +from pydantic import BaseModel + +from llama_stack.providers.utils.telemetry import tracing + +T = TypeVar("T") + + +def serialize_value(value: Any) -> str: + """Helper function to serialize values to string representation.""" + try: + if isinstance(value, BaseModel): + return value.model_dump_json() + elif isinstance(value, list) and value and isinstance(value[0], BaseModel): + return json.dumps([item.model_dump_json() for item in value]) + elif hasattr(value, "to_dict"): + return json.dumps(value.to_dict()) + elif isinstance(value, (dict, list, int, float, str, bool)): + return json.dumps(value) + else: + return str(value) + except Exception: + return str(value) + + +def trace_protocol(cls: Type[T]) -> Type[T]: + """ + A class decorator that automatically traces all methods in a protocol/base class + and its inheriting classes. + """ + + def trace_method(method: Callable) -> Callable: + is_async = asyncio.iscoroutinefunction(method) + is_async_gen = inspect.isasyncgenfunction(method) + + def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: + class_name = self.__class__.__name__ + method_name = method.__name__ + + span_type = ( + "async_generator" if is_async_gen else "async" if is_async else "sync" + ) + span_attributes = { + "class": class_name, + "method": method_name, + "type": span_type, + "args": serialize_value(args), + } + + return class_name, method_name, span_attributes + + @wraps(method) + async def async_gen_wrapper( + self: Any, *args: Any, **kwargs: Any + ) -> AsyncGenerator: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + count = 0 + async for item in method(self, *args, **kwargs): + yield item + count += 1 + finally: + span.set_attribute("chunk_count", count) + + @wraps(method) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = await method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + span.set_attribute("error", str(e)) + raise + + @wraps(method) + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + class_name, method_name, span_attributes = create_span_context( + self, *args, **kwargs + ) + + with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + try: + result = method(self, *args, **kwargs) + span.set_attribute("output", serialize_value(result)) + return result + except Exception as e: + raise + + if is_async_gen: + return async_gen_wrapper + elif is_async: + return async_wrapper + else: + return sync_wrapper + + original_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls_child, **kwargs): # noqa: N807 + if original_init_subclass: + original_init_subclass(**kwargs) + + for name, method in vars(cls_child).items(): + if inspect.isfunction(method) and not name.startswith("_"): + setattr(cls_child, name, trace_method(method)) # noqa: B010 + + cls.__init_subclass__ = classmethod(__init_subclass__) + + return cls diff --git a/llama_stack/distribution/ui/README.md b/llama_stack/distribution/ui/README.md index a91883067..2cc352c52 100644 --- a/llama_stack/distribution/ui/README.md +++ b/llama_stack/distribution/ui/README.md @@ -2,6 +2,12 @@ [!NOTE] This is a work in progress. +## Prerequisite +- Start up Llama Stack Server +``` +llama stack run +``` + ## Running Streamlit App ``` diff --git a/llama_stack/distribution/ui/app.py b/llama_stack/distribution/ui/app.py index 763b126a7..87a80e235 100644 --- a/llama_stack/distribution/ui/app.py +++ b/llama_stack/distribution/ui/app.py @@ -3,170 +3,54 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -import json - -import pandas as pd - import streamlit as st -from modules.api import LlamaStackEvaluation - -from modules.utils import process_dataset - -EVALUATION_API = LlamaStackEvaluation() - def main(): - # Add collapsible sidebar - with st.sidebar: - # Add collapse button - if "sidebar_state" not in st.session_state: - st.session_state.sidebar_state = True - - if st.session_state.sidebar_state: - st.title("Navigation") - page = st.radio( - "Select a Page", - ["Application Evaluation"], - index=0, - ) - else: - page = "Application Evaluation" # Default page when sidebar is collapsed - - # Main content area - st.title("🦙 Llama Stack Evaluations") - - if page == "Application Evaluation": - application_evaluation_page() - - -def application_evaluation_page(): - # File uploader - uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"]) - - if uploaded_file is None: - st.error("No file uploaded") - return - - # Process uploaded file - df = process_dataset(uploaded_file) - if df is None: - st.error("Error processing file") - return - - # Display dataset information - st.success("Dataset loaded successfully!") - - # Display dataframe preview - st.subheader("Dataset Preview") - st.dataframe(df) - - # Select Scoring Functions to Run Evaluation On - st.subheader("Select Scoring Functions") - scoring_functions = EVALUATION_API.list_scoring_functions() - scoring_functions = {sf.identifier: sf for sf in scoring_functions} - scoring_functions_names = list(scoring_functions.keys()) - selected_scoring_functions = st.multiselect( - "Choose one or more scoring functions", - options=scoring_functions_names, - help="Choose one or more scoring functions.", + # Evaluation pages + application_evaluation_page = st.Page( + "page/evaluations/app_eval.py", + title="Evaluations (Scoring)", + icon="📊", + default=False, + ) + native_evaluation_page = st.Page( + "page/evaluations/native_eval.py", + title="Evaluations (Generation + Scoring)", + icon="📊", + default=False, ) - available_models = EVALUATION_API.list_models() - available_models = [m.identifier for m in available_models] + # Playground pages + chat_page = st.Page( + "page/playground/chat.py", title="Chat", icon="💬", default=True + ) + rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False) - scoring_params = {} - if selected_scoring_functions: - st.write("Selected:") - for scoring_fn_id in selected_scoring_functions: - scoring_fn = scoring_functions[scoring_fn_id] - st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}") - new_params = None - if scoring_fn.params: - new_params = {} - for param_name, param_value in scoring_fn.params.to_dict().items(): - if param_name == "type": - new_params[param_name] = param_value - continue + # Distribution pages + resources_page = st.Page( + "page/distribution/resources.py", title="Resources", icon="🔍", default=False + ) + provider_page = st.Page( + "page/distribution/providers.py", + title="API Providers", + icon="🔍", + default=False, + ) - if param_name == "judge_model": - value = st.selectbox( - f"Select **{param_name}** for {scoring_fn_id}", - options=available_models, - index=0, - key=f"{scoring_fn_id}_{param_name}", - ) - new_params[param_name] = value - else: - value = st.text_area( - f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format", - value=json.dumps(param_value, indent=2), - height=80, - ) - try: - new_params[param_name] = json.loads(value) - except json.JSONDecodeError: - st.error( - f"Invalid JSON for **{param_name}** in {scoring_fn_id}" - ) - - st.json(new_params) - scoring_params[scoring_fn_id] = new_params - - # Add run evaluation button & slider - total_rows = len(df) - num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows) - - if st.button("Run Evaluation"): - progress_text = "Running evaluation..." - progress_bar = st.progress(0, text=progress_text) - rows = df.to_dict(orient="records") - if num_rows < total_rows: - rows = rows[:num_rows] - - # Create separate containers for progress text and results - progress_text_container = st.empty() - results_container = st.empty() - output_res = {} - for i, r in enumerate(rows): - # Update progress - progress = i / len(rows) - progress_bar.progress(progress, text=progress_text) - - # Run evaluation for current row - score_res = EVALUATION_API.run_scoring( - r, - scoring_function_ids=selected_scoring_functions, - scoring_params=scoring_params, - ) - - for k in r.keys(): - if k not in output_res: - output_res[k] = [] - output_res[k].append(r[k]) - - for fn_id in selected_scoring_functions: - if fn_id not in output_res: - output_res[fn_id] = [] - output_res[fn_id].append(score_res.results[fn_id].score_rows[0]) - - # Display current row results using separate containers - progress_text_container.write( - f"Expand to see current processed result ({i+1}/{len(rows)})" - ) - results_container.json( - score_res.to_json(), - expanded=2, - ) - - progress_bar.progress(1.0, text="Evaluation complete!") - - # Display results in dataframe - if output_res: - output_df = pd.DataFrame(output_res) - st.subheader("Evaluation Results") - st.dataframe(output_df) + pg = st.navigation( + { + "Playground": [ + chat_page, + rag_page, + application_evaluation_page, + native_evaluation_page, + ], + "Inspect": [provider_page, resources_page], + }, + expanded=False, + ) + pg.run() if __name__ == "__main__": diff --git a/llama_stack/providers/remote/telemetry/__init__.py b/llama_stack/distribution/ui/modules/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/__init__.py rename to llama_stack/distribution/ui/modules/__init__.py diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index 8bcce5b8a..807eab19d 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -11,7 +11,7 @@ from typing import Optional from llama_stack_client import LlamaStackClient -class LlamaStackEvaluation: +class LlamaStackApi: def __init__(self): self.client = LlamaStackClient( base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"), @@ -23,14 +23,6 @@ class LlamaStackEvaluation: }, ) - def list_scoring_functions(self): - """List all available scoring functions""" - return self.client.scoring_functions.list() - - def list_models(self): - """List all available judge models""" - return self.client.models.list() - def run_scoring( self, row, scoring_function_ids: list[str], scoring_params: Optional[dict] ): @@ -40,3 +32,6 @@ class LlamaStackEvaluation: return self.client.scoring.score( input_rows=[row], scoring_functions=scoring_params ) + + +llama_stack_api = LlamaStackApi() diff --git a/llama_stack/distribution/ui/modules/utils.py b/llama_stack/distribution/ui/modules/utils.py index f8da2e54e..67cce98fa 100644 --- a/llama_stack/distribution/ui/modules/utils.py +++ b/llama_stack/distribution/ui/modules/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 import os import pandas as pd @@ -29,3 +30,13 @@ def process_dataset(file): except Exception as e: st.error(f"Error processing file: {str(e)}") return None + + +def data_url_from_file(file) -> str: + file_content = file.getvalue() + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type = file.type + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url diff --git a/llama_stack/distribution/ui/page/__init__.py b/llama_stack/distribution/ui/page/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py new file mode 100644 index 000000000..44e314cde --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/datasets.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def datasets(): + st.header("Datasets") + + datasets_info = { + d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list() + } + + selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys())) + st.json(datasets_info[selected_dataset], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py new file mode 100644 index 000000000..4957fb178 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def eval_tasks(): + # Eval Tasks Section + st.header("Eval Tasks") + + eval_tasks_info = { + d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list() + } + + selected_eval_task = st.selectbox( + "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect" + ) + st.json(eval_tasks_info[selected_eval_task], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/memory_banks.py b/llama_stack/distribution/ui/page/distribution/memory_banks.py new file mode 100644 index 000000000..f28010bf2 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/memory_banks.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def memory_banks(): + st.header("Memory Banks") + memory_banks_info = { + m.identifier: m.to_dict() for m in llama_stack_api.client.memory_banks.list() + } + + if len(memory_banks_info) > 0: + selected_memory_bank = st.selectbox( + "Select a memory bank", list(memory_banks_info.keys()) + ) + st.json(memory_banks_info[selected_memory_bank]) + else: + st.info("No memory banks found") diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py new file mode 100644 index 000000000..70b166f2e --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/models.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def models(): + # Models Section + st.header("Models") + models_info = { + m.identifier: m.to_dict() for m in llama_stack_api.client.models.list() + } + + selected_model = st.selectbox("Select a model", list(models_info.keys())) + st.json(models_info[selected_model]) diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py new file mode 100644 index 000000000..69f6bd771 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def providers(): + st.header("🔍 API Providers") + apis_providers_info = llama_stack_api.client.providers.list() + # selected_api = st.selectbox("Select an API", list(apis_providers_info.keys())) + for api in apis_providers_info.keys(): + st.markdown(f"###### {api}") + st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500) + + +providers() diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py new file mode 100644 index 000000000..6b3ea0e3a --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from page.distribution.datasets import datasets +from page.distribution.eval_tasks import eval_tasks +from page.distribution.memory_banks import memory_banks +from page.distribution.models import models +from page.distribution.scoring_functions import scoring_functions +from page.distribution.shields import shields + +from streamlit_option_menu import option_menu + + +def resources_page(): + options = [ + "Models", + "Memory Banks", + "Shields", + "Scoring Functions", + "Datasets", + "Eval Tasks", + ] + icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"] + selected_resource = option_menu( + None, + options, + icons=icons, + orientation="horizontal", + styles={ + "nav-link": { + "font-size": "12px", + }, + }, + ) + if selected_resource == "Eval Tasks": + eval_tasks() + elif selected_resource == "Memory Banks": + memory_banks() + elif selected_resource == "Datasets": + datasets() + elif selected_resource == "Models": + models() + elif selected_resource == "Scoring Functions": + scoring_functions() + elif selected_resource == "Shields": + shields() + + +resources_page() diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py new file mode 100644 index 000000000..581ae0db7 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def scoring_functions(): + st.header("Scoring Functions") + + scoring_functions_info = { + s.identifier: s.to_dict() + for s in llama_stack_api.client.scoring_functions.list() + } + + selected_scoring_function = st.selectbox( + "Select a scoring function", list(scoring_functions_info.keys()) + ) + st.json(scoring_functions_info[selected_scoring_function], expanded=True) diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py new file mode 100644 index 000000000..18bbfc008 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/shields.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + + +def shields(): + # Shields Section + st.header("Shields") + + shields_info = { + s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list() + } + + selected_shield = st.selectbox("Select a shield", list(shields_info.keys())) + st.json(shields_info[selected_shield]) diff --git a/llama_stack/distribution/ui/page/evaluations/__init__.py b/llama_stack/distribution/ui/page/evaluations/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py new file mode 100644 index 000000000..5ec47ed45 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +import pandas as pd +import streamlit as st + +from modules.api import llama_stack_api +from modules.utils import process_dataset + + +def application_evaluation_page(): + + st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙") + st.title("📊 Evaluations (Scoring)") + + # File uploader + uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"]) + + if uploaded_file is None: + st.error("No file uploaded") + return + + # Process uploaded file + df = process_dataset(uploaded_file) + if df is None: + st.error("Error processing file") + return + + # Display dataset information + st.success("Dataset loaded successfully!") + + # Display dataframe preview + st.subheader("Dataset Preview") + st.dataframe(df) + + # Select Scoring Functions to Run Evaluation On + st.subheader("Select Scoring Functions") + scoring_functions = llama_stack_api.client.scoring_functions.list() + scoring_functions = {sf.identifier: sf for sf in scoring_functions} + scoring_functions_names = list(scoring_functions.keys()) + selected_scoring_functions = st.multiselect( + "Choose one or more scoring functions", + options=scoring_functions_names, + help="Choose one or more scoring functions.", + ) + + available_models = llama_stack_api.client.models.list() + available_models = [m.identifier for m in available_models] + + scoring_params = {} + if selected_scoring_functions: + st.write("Selected:") + for scoring_fn_id in selected_scoring_functions: + scoring_fn = scoring_functions[scoring_fn_id] + st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}") + new_params = None + if scoring_fn.params: + new_params = {} + for param_name, param_value in scoring_fn.params.to_dict().items(): + if param_name == "type": + new_params[param_name] = param_value + continue + + if param_name == "judge_model": + value = st.selectbox( + f"Select **{param_name}** for {scoring_fn_id}", + options=available_models, + index=0, + key=f"{scoring_fn_id}_{param_name}", + ) + new_params[param_name] = value + else: + value = st.text_area( + f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format", + value=json.dumps(param_value, indent=2), + height=80, + ) + try: + new_params[param_name] = json.loads(value) + except json.JSONDecodeError: + st.error( + f"Invalid JSON for **{param_name}** in {scoring_fn_id}" + ) + + st.json(new_params) + scoring_params[scoring_fn_id] = new_params + + # Add run evaluation button & slider + total_rows = len(df) + num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows) + + if st.button("Run Evaluation"): + progress_text = "Running evaluation..." + progress_bar = st.progress(0, text=progress_text) + rows = df.to_dict(orient="records") + if num_rows < total_rows: + rows = rows[:num_rows] + + # Create separate containers for progress text and results + progress_text_container = st.empty() + results_container = st.empty() + output_res = {} + for i, r in enumerate(rows): + # Update progress + progress = i / len(rows) + progress_bar.progress(progress, text=progress_text) + + # Run evaluation for current row + score_res = llama_stack_api.run_scoring( + r, + scoring_function_ids=selected_scoring_functions, + scoring_params=scoring_params, + ) + + for k in r.keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(r[k]) + + for fn_id in selected_scoring_functions: + if fn_id not in output_res: + output_res[fn_id] = [] + output_res[fn_id].append(score_res.results[fn_id].score_rows[0]) + + # Display current row results using separate containers + progress_text_container.write( + f"Expand to see current processed result ({i+1}/{len(rows)})" + ) + results_container.json( + score_res.to_json(), + expanded=2, + ) + + progress_bar.progress(1.0, text="Evaluation complete!") + + # Display results in dataframe + if output_res: + output_df = pd.DataFrame(output_res) + st.subheader("Evaluation Results") + st.dataframe(output_df) + + +application_evaluation_page() diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py new file mode 100644 index 000000000..b8cc8bfa6 --- /dev/null +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +import pandas as pd + +import streamlit as st + +from modules.api import llama_stack_api + + +def select_eval_task_1(): + # Select Eval Tasks + st.subheader("1. Choose An Eval Task") + eval_tasks = llama_stack_api.client.eval_tasks.list() + eval_tasks = {et.identifier: et for et in eval_tasks} + eval_tasks_names = list(eval_tasks.keys()) + selected_eval_task = st.selectbox( + "Choose an eval task.", + options=eval_tasks_names, + help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.", + ) + with st.expander("View Eval Task"): + st.json(eval_tasks[selected_eval_task], expanded=True) + + st.session_state["selected_eval_task"] = selected_eval_task + st.session_state["eval_tasks"] = eval_tasks + if st.button("Confirm", key="confirm_1"): + st.session_state["selected_eval_task_1_next"] = True + + +def define_eval_candidate_2(): + if not st.session_state.get("selected_eval_task_1_next", None): + return + + st.subheader("2. Define Eval Candidate") + st.info( + """ + Define the configurations for the evaluation candidate model or agent used for generation. + Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig. + """ + ) + with st.expander("Define Eval Candidate", expanded=True): + # Define Eval Candidate + candidate_type = st.radio("Candidate Type", ["model", "agent"]) + + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + + # Sampling Parameters + st.markdown("##### Sampling Parameters") + strategy = st.selectbox( + "Strategy", + ["greedy", "top_p", "top_k"], + index=0, + ) + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + max_tokens = st.slider( + "Max Tokens", + min_value=0, + max_value=4096, + value=512, + step=1, + help="The maximum number of tokens to generate", + ) + repetition_penalty = st.slider( + "Repetition Penalty", + min_value=1.0, + max_value=2.0, + value=1.0, + step=0.1, + help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", + ) + if candidate_type == "model": + eval_candidate = { + "type": "model", + "model": selected_model, + "sampling_params": { + "strategy": strategy, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "repetition_penalty": repetition_penalty, + }, + } + elif candidate_type == "agent": + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful AI assistant.", + help="Initial instructions given to the AI to set its behavior and context", + ) + tools_json = st.text_area( + "Tools Configuration (JSON)", + value=json.dumps( + [ + { + "type": "brave_search", + "engine": "brave", + "api_key": "ENTER_BRAVE_API_KEY_HERE", + } + ] + ), + help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.", + height=200, + ) + try: + tools = json.loads(tools_json) + except json.JSONDecodeError: + st.error("Invalid JSON format for tools configuration") + tools = [] + eval_candidate = { + "type": "agent", + "config": { + "model": selected_model, + "instructions": system_prompt, + "tools": tools, + "tool_choice": "auto", + "tool_prompt_format": "json", + "input_shields": [], + "output_shields": [], + "enable_session_persistence": False, + }, + } + st.session_state["eval_candidate"] = eval_candidate + + if st.button("Confirm", key="confirm_2"): + st.session_state["selected_eval_candidate_2_next"] = True + + +def run_evaluation_3(): + if not st.session_state.get("selected_eval_candidate_2_next", None): + return + + st.subheader("3. Run Evaluation") + # Add info box to explain configurations being used + st.info( + """ + Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button. + """ + ) + selected_eval_task = st.session_state["selected_eval_task"] + eval_tasks = st.session_state["eval_tasks"] + eval_candidate = st.session_state["eval_candidate"] + + dataset_id = eval_tasks[selected_eval_task].dataset_id + rows = llama_stack_api.client.datasetio.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + total_rows = len(rows.rows) + # Add number of examples control + num_rows = st.number_input( + "Number of Examples to Evaluate", + min_value=1, + max_value=total_rows, + value=5, + help="Number of examples from the dataset to evaluate. ", + ) + + eval_task_config = { + "type": "benchmark", + "eval_candidate": eval_candidate, + "scoring_params": {}, + } + + with st.expander("View Evaluation Task", expanded=True): + st.json(eval_tasks[selected_eval_task], expanded=True) + with st.expander("View Evaluation Task Configuration", expanded=True): + st.json(eval_task_config, expanded=True) + + # Add run button and handle evaluation + if st.button("Run Evaluation"): + + progress_text = "Running evaluation..." + progress_bar = st.progress(0, text=progress_text) + rows = rows.rows + if num_rows < total_rows: + rows = rows[:num_rows] + + # Create separate containers for progress text and results + progress_text_container = st.empty() + results_container = st.empty() + output_res = {} + for i, r in enumerate(rows): + # Update progress + progress = i / len(rows) + progress_bar.progress(progress, text=progress_text) + # Run evaluation for current row + eval_res = llama_stack_api.client.eval.evaluate_rows( + task_id=selected_eval_task, + input_rows=[r], + scoring_functions=eval_tasks[selected_eval_task].scoring_functions, + task_config=eval_task_config, + ) + + for k in r.keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(r[k]) + + for k in eval_res.generations[0].keys(): + if k not in output_res: + output_res[k] = [] + output_res[k].append(eval_res.generations[0][k]) + + for scoring_fn in eval_tasks[selected_eval_task].scoring_functions: + if scoring_fn not in output_res: + output_res[scoring_fn] = [] + output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) + + progress_text_container.write( + f"Expand to see current processed result ({i+1}/{len(rows)})" + ) + results_container.json(eval_res, expanded=2) + + progress_bar.progress(1.0, text="Evaluation complete!") + # Display results in dataframe + if output_res: + output_df = pd.DataFrame(output_res) + st.subheader("Evaluation Results") + st.dataframe(output_df) + + +def native_evaluation_page(): + + st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙") + st.title("📊 Evaluations (Generation + Scoring)") + + select_eval_task_1() + define_eval_candidate_2() + run_evaluation_3() + + +native_evaluation_page() diff --git a/llama_stack/distribution/ui/page/playground/__init__.py b/llama_stack/distribution/ui/page/playground/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py new file mode 100644 index 000000000..157922d3b --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from modules.api import llama_stack_api + +# Sidebar configurations +with st.sidebar: + st.header("Configuration") + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + + max_tokens = st.slider( + "Max Tokens", + min_value=0, + max_value=4096, + value=512, + step=1, + help="The maximum number of tokens to generate", + ) + + repetition_penalty = st.slider( + "Repetition Penalty", + min_value=1.0, + max_value=2.0, + value=1.0, + step=0.1, + help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.", + ) + + stream = st.checkbox("Stream", value=True) + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful AI assistant.", + help="Initial instructions given to the AI to set its behavior and context", + ) + + # Add clear chat button to sidebar + if st.button("Clear Chat", use_container_width=True): + st.session_state.messages = [] + st.rerun() + + +# Main chat interface +st.title("🦙 Chat") + + +# Initialize chat history +if "messages" not in st.session_state: + st.session_state.messages = [] + +# Display chat messages +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + +# Chat input +if prompt := st.chat_input("Example: What is Llama Stack?"): + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) + + # Display user message + with st.chat_message("user"): + st.markdown(prompt) + + # Display assistant response + with st.chat_message("assistant"): + message_placeholder = st.empty() + full_response = "" + + response = llama_stack_api.client.inference.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + model_id=selected_model, + stream=stream, + sampling_params={ + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "repetition_penalty": repetition_penalty, + }, + ) + + if stream: + for chunk in response: + if chunk.event.event_type == "progress": + full_response += chunk.event.delta + message_placeholder.markdown(full_response + "▌") + message_placeholder.markdown(full_response) + else: + full_response = response + message_placeholder.markdown(full_response.completion_message.content) + + st.session_state.messages.append( + {"role": "assistant", "content": full_response} + ) diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py new file mode 100644 index 000000000..ffcaf1afd --- /dev/null +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import streamlit as st +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.memory_insert_params import Document + +from modules.api import llama_stack_api +from modules.utils import data_url_from_file + + +def rag_chat_page(): + st.title("🦙 RAG") + + with st.sidebar: + # File/Directory Upload Section + st.subheader("Upload Documents") + uploaded_files = st.file_uploader( + "Upload file(s) or directory", + accept_multiple_files=True, + type=["txt", "pdf", "doc", "docx"], # Add more file types as needed + ) + # Process uploaded files + if uploaded_files: + st.success(f"Successfully uploaded {len(uploaded_files)} files") + # Add memory bank name input field + memory_bank_name = st.text_input( + "Memory Bank Name", + value="rag_bank", + help="Enter a unique identifier for this memory bank", + ) + if st.button("Create Memory Bank"): + documents = [ + Document( + document_id=uploaded_file.name, + content=data_url_from_file(uploaded_file), + ) + for i, uploaded_file in enumerate(uploaded_files) + ] + + providers = llama_stack_api.client.providers.list() + llama_stack_api.client.memory_banks.register( + memory_bank_id=memory_bank_name, # Use the user-provided name + params={ + "embedding_model": "all-MiniLM-L6-v2", + "chunk_size_in_tokens": 512, + "overlap_size_in_tokens": 64, + }, + provider_id=providers["memory"][0].provider_id, + ) + + # insert documents using the custom bank name + llama_stack_api.client.memory.insert( + bank_id=memory_bank_name, # Use the user-provided name + documents=documents, + ) + st.success("Memory bank created successfully!") + + st.subheader("Configure Agent") + # select memory banks + memory_banks = llama_stack_api.client.memory_banks.list() + memory_banks = [bank.identifier for bank in memory_banks] + selected_memory_banks = st.multiselect( + "Select Memory Banks", + memory_banks, + ) + memory_bank_configs = [ + {"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks + ] + + available_models = llama_stack_api.client.models.list() + available_models = [model.identifier for model in available_models] + selected_model = st.selectbox( + "Choose a model", + available_models, + index=0, + ) + system_prompt = st.text_area( + "System Prompt", + value="You are a helpful assistant. ", + help="Initial instructions given to the AI to set its behavior and context", + ) + temperature = st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=0.0, + step=0.1, + help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", + ) + + top_p = st.slider( + "Top P", + min_value=0.0, + max_value=1.0, + value=0.95, + step=0.1, + ) + + # Add clear chat button to sidebar + if st.button("Clear Chat", use_container_width=True): + st.session_state.messages = [] + st.rerun() + + # Chat Interface + if "messages" not in st.session_state: + st.session_state.messages = [] + + # Display chat history + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + selected_model = llama_stack_api.client.models.list()[0].identifier + + agent_config = AgentConfig( + model=selected_model, + instructions=system_prompt, + sampling_params={ + "strategy": "greedy", + "temperature": temperature, + "top_p": top_p, + }, + tools=[ + { + "type": "memory", + "memory_bank_configs": memory_bank_configs, + "query_generator_config": {"type": "default", "sep": " "}, + "max_tokens_in_context": 4096, + "max_chunks": 10, + } + ], + tool_choice="auto", + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + + agent = Agent(llama_stack_api.client, agent_config) + session_id = agent.create_session("rag-session") + + # Chat input + if prompt := st.chat_input("Ask a question about your documents"): + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) + + # Display user message + with st.chat_message("user"): + st.markdown(prompt) + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + session_id=session_id, + ) + + # Display assistant response + with st.chat_message("assistant"): + retrieval_message_placeholder = st.empty() + message_placeholder = st.empty() + full_response = "" + retrieval_response = "" + for log in EventLogger().log(response): + log.print() + if log.role == "memory_retrieval": + retrieval_response += log.content.replace("====", "").strip() + retrieval_message_placeholder.info(retrieval_response) + else: + full_response += log.content + message_placeholder.markdown(full_response + "▌") + message_placeholder.markdown(full_response) + + st.session_state.messages.append( + {"role": "assistant", "content": full_response} + ) + + +rag_chat_page() diff --git a/llama_stack/distribution/ui/requirements.txt b/llama_stack/distribution/ui/requirements.txt index c03959444..39f2b3d27 100644 --- a/llama_stack/distribution/ui/requirements.txt +++ b/llama_stack/distribution/ui/requirements.txt @@ -1,3 +1,4 @@ streamlit pandas llama-stack-client>=0.0.55 +streamlit-option-menu diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 080204e45..8e89bcc72 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol): async def register_dataset(self, dataset: Dataset) -> None: ... + async def unregister_dataset(self, dataset_id: str) -> None: ... + class ScoringFunctionsProtocolPrivate(Protocol): async def list_scoring_functions(self) -> List[ScoringFn]: ... diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8f800ad6f..7df5d3bd4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) - @tracing.span("create_and_execute_turn") async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - assert request.stream is True, "Non-streaming not supported" + with tracing.span("create_and_execute_turn") as span: + span.set_attribute("session_id", request.session_id) + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" - session_info = await self.storage.get_session_info(request.session_id) - if session_info is None: - raise ValueError(f"Session {request.session_id} not found") + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") - turns = await self.storage.get_session_turns(request.session_id) + turns = await self.storage.get_session_turns(request.session_id) - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) + for i, turn in enumerate(turns): + messages.extend(self.turn_to_messages(turn)) - messages.extend(request.messages) + messages.extend(request.messages) - turn_id = str(uuid.uuid4()) - start_time = datetime.now() - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnStartPayload( - turn_id=turn_id, + turn_id = str(uuid.uuid4()) + span.set_attribute("turn_id", turn_id) + start_time = datetime.now() + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnStartPayload( + turn_id=turn_id, + ) ) ) - ) - steps = [] - output_message = None - async for chunk in self.run( - session_id=request.session_id, - turn_id=turn_id, - input_messages=messages, - attachments=request.attachments or [], - sampling_params=self.agent_config.sampling_params, - stream=request.stream, - ): - if isinstance(chunk, CompletionMessage): - log.info( - f"{chunk.role.capitalize()}: {chunk.content}", - ) - output_message = chunk - continue - - assert isinstance( - chunk, AgentTurnResponseStreamChunk - ), f"Unexpected type {type(chunk)}" - event = chunk.event - if ( - event.payload.event_type - == AgentTurnResponseEventType.step_complete.value + steps = [] + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=turn_id, + input_messages=messages, + attachments=request.attachments or [], + sampling_params=self.agent_config.sampling_params, + stream=request.stream, ): - steps.append(event.payload.step_details) + if isinstance(chunk, CompletionMessage): + log.info( + f"{chunk.role.capitalize()}: {chunk.content}", + ) + output_message = chunk + continue - yield chunk + assert isinstance( + chunk, AgentTurnResponseStreamChunk + ), f"Unexpected type {type(chunk)}" + event = chunk.event + if ( + event.payload.event_type + == AgentTurnResponseEventType.step_complete.value + ): + steps.append(event.payload.step_details) - assert output_message is not None + yield chunk - turn = Turn( - turn_id=turn_id, - session_id=request.session_id, - input_messages=request.messages, - output_message=output_message, - started_at=start_time, - completed_at=datetime.now(), - steps=steps, - ) - await self.storage.add_turn_to_session(request.session_id, turn) + assert output_message is not None - chunk = AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + turn = Turn( + turn_id=turn_id, + session_id=request.session_id, + input_messages=request.messages, + output_message=output_message, + started_at=start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) ) ) - ) - yield chunk + yield chunk async def run( self, @@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin): yield final_response - @tracing.span("run_shields") async def run_multiple_shields_wrapper( self, turn_id: str, @@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - if len(shields) == 0: - return + with tracing.span("run_shields") as span: + span.set_attribute("turn_id", turn_id) + span.set_attribute("input", [m.model_dump_json() for m in messages]) + if len(shields) == 0: + span.set_attribute("output", "no shields") + return - step_id = str(uuid.uuid4()) - try: - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.shield_call.value, - step_id=step_id, - metadata=dict(touchpoint=touchpoint), + step_id = str(uuid.uuid4()) + try: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.shield_call.value, + step_id=step_id, + metadata=dict(touchpoint=touchpoint), + ) ) ) - ) - await self.run_multiple_shields(messages, shields) + await self.run_multiple_shields(messages, shields) + + except SafetyException as e: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.shield_call.value, + step_details=ShieldCallStep( + step_id=step_id, + turn_id=turn_id, + violation=e.violation, + ), + ) + ) + ) + span.set_attribute("output", e.violation.model_dump_json()) + + yield CompletionMessage( + content=str(e), + stop_reason=StopReason.end_of_turn, + ) + yield False - except SafetyException as e: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( @@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin): step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, - violation=e.violation, + violation=None, ), ) ) ) - - yield CompletionMessage( - content=str(e), - stop_reason=StopReason.end_of_turn, - ) - yield False - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, - step_details=ShieldCallStep( - step_id=step_id, - turn_id=turn_id, - violation=None, - ), - ) - ) - ) + span.set_attribute("output", "no violations") async def _run( self, @@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context"): + with tracing.span("retrieve_rag_context") as span: rag_context, bank_ids = await self._retrieve_context( session_id, input_messages, attachments ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", rag_context) + span.set_attribute("bank_ids", bank_ids) step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.span("inference"): + with tracing.span("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin): if isinstance(delta, ToolCallDelta): if delta.parse_status == ToolCallParseStatus.success: tool_calls.append(delta.content) - if stream: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin): if event.stop_reason is not None: stop_reason = event.stop_reason + span.set_attribute("stop_reason", stop_reason) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute( + "output", f"content: {content} tool_calls: {tool_calls}" + ) stop_reason = stop_reason or StopReason.out_of_tokens @@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin): ) ) - with tracing.span("tool_execution"): + with tracing.span( + "tool_execution", + { + "tool_name": tool_call.tool_name, + "input": message.model_dump_json(), + }, + ) as span: result_messages = await execute_tool_call_maybe( self.tools_dict, [message], @@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin): len(result_messages) == 1 ), "Currently not supporting multiple messages" result_message = result_messages[0] + span.set_attribute("output", result_message.model_dump_json()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 4de1850ae..736e5d8b9 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,14 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +import base64 +import os from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import urlparse from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -97,6 +100,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_impl=dataset_impl, ) + async def unregister_dataset(self, dataset_id: str) -> None: + del self.dataset_infos[dataset_id] + async def get_rows_paginated( self, dataset_id: str, @@ -128,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_info = self.dataset_infos.get(dataset_id) + if dataset_info is None: + raise ValueError(f"Dataset with id {dataset_id} not found") + + dataset_impl = dataset_info.dataset_impl + dataset_impl.load() + + new_rows_df = pandas.DataFrame(rows) + new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) + + url = str(dataset_info.dataset_def.url) + parsed_url = urlparse(url) + + if parsed_url.scheme == "file" or not parsed_url.scheme: + file_path = parsed_url.path + os.makedirs(os.path.dirname(file_path), exist_ok=True) + dataset_impl.df.to_csv(file_path, index=False) + elif parsed_url.scheme == "data": + # For data URLs, we need to update the base64-encoded content + if not parsed_url.path.startswith("text/csv;base64,"): + raise ValueError("Data URL must be a base64-encoded CSV") + + csv_buffer = dataset_impl.df.to_csv(index=False) + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( + "utf-8" + ) + dataset_info.dataset_def.url = URL( + uri=f"data:text/csv;base64,{base64_content}" + ) + else: + raise ValueError( + f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." + ) diff --git a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py b/llama_stack/providers/inline/meta_reference/telemetry/__init__.py deleted file mode 100644 index 4a0c2f6ee..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import ConsoleConfig - - -async def get_provider_impl(config: ConsoleConfig, _deps): - from .console import ConsoleTelemetryImpl - - impl = ConsoleTelemetryImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/meta_reference/telemetry/config.py b/llama_stack/providers/inline/meta_reference/telemetry/config.py deleted file mode 100644 index a1db1d4d8..000000000 --- a/llama_stack/providers/inline/meta_reference/telemetry/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum - -from llama_models.schema_utils import json_schema_type - -from pydantic import BaseModel - - -class LogFormat(Enum): - TEXT = "text" - JSON = "json" - - -@json_schema_type -class ConsoleConfig(BaseModel): - log_format: LogFormat = LogFormat.TEXT diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py index d8ef49481..838aaa4e1 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import json -from typing import Optional +from typing import List, Optional from .config import LogFormat @@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry): if formatted: print(formatted) - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError() + async def query_traces( + self, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + raise NotImplementedError("Console telemetry does not support trace querying") + + async def get_spans( + self, + span_id: str, + attribute_conditions: Optional[List[QueryCondition]] = None, + attribute_keys_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> SpanWithChildren: + raise NotImplementedError("Console telemetry does not support span querying") COLORS = { diff --git a/llama_stack/providers/inline/scoring/__init__.py b/llama_stack/providers/inline/scoring/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/scoring/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py index dc4ea4951..2ddc58bd2 100644 --- a/llama_stack/providers/inline/scoring/braintrust/__init__.py +++ b/llama_stack/providers/inline/scoring/braintrust/__init__.py @@ -5,9 +5,10 @@ # the root directory of this source tree. from typing import Dict -from llama_stack.distribution.datatypes import Api, ProviderSpec from pydantic import BaseModel +from llama_stack.distribution.datatypes import Api, ProviderSpec + from .config import BraintrustScoringConfig diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index cf6e22a29..ee515d588 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -16,6 +16,7 @@ import os from autoevals.llm import Factuality from autoevals.ragas import AnswerCorrectness + from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py index b00b9a7db..0b18bac01 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn llm_as_judge_base = ScoringFn( @@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn( return_type=NumberType(), provider_id="llm-as-judge", provider_resource_id="llm-as-judge-base", + params=LLMAsJudgeScoringFnParams( + judge_model="meta-llama/Llama-3.1-405B-Instruct", + prompt_template="Enter custom LLM as Judge Prompt Template", + ), ) diff --git a/llama_stack/providers/inline/telemetry/__init__.py b/llama_stack/providers/inline/telemetry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py new file mode 100644 index 000000000..6213d5536 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from .config import TelemetryConfig, TelemetrySink +from .telemetry import TelemetryAdapter + +__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"] + + +async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]): + impl = TelemetryAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py new file mode 100644 index 000000000..0230d24d2 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel, Field + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + + +class TelemetrySink(str, Enum): + JAEGER = "jaeger" + SQLITE = "sqlite" + CONSOLE = "console" + + +class TelemetryConfig(BaseModel): + otel_endpoint: str = Field( + default="http://localhost:4318/v1/traces", + description="The OpenTelemetry collector endpoint URL", + ) + service_name: str = Field( + default="llama-stack", + description="The service name to use for telemetry", + ) + sinks: List[TelemetrySink] = Field( + default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], + description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)", + ) + sqlite_db_path: str = Field( + default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(), + description="The path to the SQLite database to use for storing traces", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", + "sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}", + "sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}", + } diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py new file mode 100644 index 000000000..8d6f779e6 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanProcessor + +# Colors for console output +COLORS = { + "reset": "\033[0m", + "bold": "\033[1m", + "dim": "\033[2m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", +} + + +class ConsoleSpanProcessor(SpanProcessor): + """A SpanProcessor that prints spans to the console with color formatting.""" + + def on_start(self, span: ReadableSpan, parent_context=None) -> None: + """Called when a span starts.""" + timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + print( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[START]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']}" + ) + + def on_end(self, span: ReadableSpan) -> None: + """Called when a span ends.""" + timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + + # Build the span context string + span_context = ( + f"{COLORS['dim']}{timestamp}{COLORS['reset']} " + f"{COLORS['magenta']}[END]{COLORS['reset']} " + f"{COLORS['cyan']}{span.name}{COLORS['reset']} " + ) + + # Add status if not OK + if span.status.status_code != 0: # UNSET or ERROR + status_color = ( + COLORS["red"] if span.status.status_code == 2 else COLORS["yellow"] + ) + span_context += ( + f" {status_color}[{span.status.status_code}]{COLORS['reset']}" + ) + + # Add duration + duration_ms = (span.end_time - span.start_time) / 1e6 + span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}" + + # Print the main span line + print(span_context) + + # Print attributes indented + if span.attributes: + for key, value in span.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + # Print events indented + for event in span.events: + event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime( + "%H:%M:%S.%f" + )[:-3] + print( + f" {COLORS['dim']}{event_time}{COLORS['reset']} " + f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}" + ) + if event.attributes: + for key, value in event.attributes.items(): + print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + + def shutdown(self) -> None: + """Shutdown the processor.""" + pass + + def force_flush(self, timeout_millis: float = None) -> bool: + """Force flush any pending spans.""" + return True diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py new file mode 100644 index 000000000..553dd5000 --- /dev/null +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import os +import sqlite3 +import threading +from datetime import datetime, timedelta +from typing import Dict + +from opentelemetry.sdk.trace import SpanProcessor +from opentelemetry.trace import Span + + +class SQLiteSpanProcessor(SpanProcessor): + def __init__(self, conn_string, ttl_days=30): + """Initialize the SQLite span processor with a connection string.""" + self.conn_string = conn_string + self.ttl_days = ttl_days + self.cleanup_task = None + self._thread_local = threading.local() + self._connections: Dict[int, sqlite3.Connection] = {} + self._lock = threading.Lock() + self.setup_database() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-specific database connection.""" + thread_id = threading.get_ident() + with self._lock: + if thread_id not in self._connections: + conn = sqlite3.connect(self.conn_string) + self._connections[thread_id] = conn + return self._connections[thread_id] + + def setup_database(self): + """Create the necessary tables if they don't exist.""" + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self.conn_string), exist_ok=True) + + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS traces ( + trace_id TEXT PRIMARY KEY, + service_name TEXT, + root_span_id TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT REFERENCES traces(trace_id), + parent_span_id TEXT, + name TEXT, + start_time TIMESTAMP, + end_time TIMESTAMP, + attributes TEXT, + status TEXT, + kind TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS span_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + span_id TEXT REFERENCES spans(span_id), + name TEXT, + timestamp TIMESTAMP, + attributes TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_traces_created_at + ON traces(created_at) + """ + ) + + conn.commit() + cursor.close() + + # Start periodic cleanup in a separate thread + self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True) + self.cleanup_task.start() + + def _cleanup_old_data(self): + """Delete records older than TTL.""" + try: + conn = self._get_connection() + cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat() + cursor = conn.cursor() + + # Delete old span events + cursor.execute( + """ + DELETE FROM span_events + WHERE span_id IN ( + SELECT span_id FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + ) + """, + (cutoff_date,), + ) + + # Delete old spans + cursor.execute( + """ + DELETE FROM spans + WHERE trace_id IN ( + SELECT trace_id FROM traces + WHERE created_at < ? + ) + """, + (cutoff_date,), + ) + + # Delete old traces + cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,)) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error during cleanup: {e}") + + def _periodic_cleanup(self): + """Run cleanup periodically.""" + import time + + while True: + time.sleep(3600) # Sleep for 1 hour + self._cleanup_old_data() + + def on_start(self, span: Span, parent_context=None): + """Called when a span starts.""" + pass + + def on_end(self, span: Span): + """Called when a span ends. Export the span data to SQLite.""" + try: + conn = self._get_connection() + cursor = conn.cursor() + + trace_id = format(span.get_span_context().trace_id, "032x") + span_id = format(span.get_span_context().span_id, "016x") + service_name = span.resource.attributes.get("service.name", "unknown") + + parent_span_id = None + parent_context = span.parent + if parent_context: + parent_span_id = format(parent_context.span_id, "016x") + + # Insert into traces + cursor.execute( + """ + INSERT INTO traces ( + trace_id, service_name, root_span_id, start_time, end_time + ) VALUES (?, ?, ?, ?, ?) + ON CONFLICT(trace_id) DO UPDATE SET + root_span_id = COALESCE(root_span_id, excluded.root_span_id), + start_time = MIN(excluded.start_time, start_time), + end_time = MAX(excluded.end_time, end_time) + """, + ( + trace_id, + service_name, + (span_id if not parent_span_id else None), + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + ), + ) + + # Insert into spans + cursor.execute( + """ + INSERT INTO spans ( + span_id, trace_id, parent_span_id, name, + start_time, end_time, attributes, status, + kind + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + span_id, + trace_id, + parent_span_id, + span.name, + datetime.fromtimestamp(span.start_time / 1e9).isoformat(), + datetime.fromtimestamp(span.end_time / 1e9).isoformat(), + json.dumps(dict(span.attributes)), + span.status.status_code.name, + span.kind.name, + ), + ) + + for event in span.events: + cursor.execute( + """ + INSERT INTO span_events ( + span_id, name, timestamp, attributes + ) VALUES (?, ?, ?, ?) + """, + ( + span_id, + event.name, + datetime.fromtimestamp(event.timestamp / 1e9).isoformat(), + json.dumps(dict(event.attributes)), + ), + ) + + conn.commit() + cursor.close() + except Exception as e: + print(f"Error exporting span to SQLite: {e}") + + def shutdown(self): + """Cleanup any resources.""" + with self._lock: + for conn in self._connections.values(): + if conn: + conn.close() + self._connections.clear() + + def force_flush(self, timeout_millis=30000): + """Force export of spans.""" + pass diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py similarity index 67% rename from llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py rename to llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index c9830fd9d..6540a667f 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import threading +from typing import List, Optional from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -16,10 +17,18 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( + ConsoleSpanProcessor, +) + +from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( + SQLiteSpanProcessor, +) +from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore from llama_stack.apis.telemetry import * # noqa: F403 -from .config import OpenTelemetryConfig +from .config import TelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { "active_spans": {}, @@ -45,8 +54,8 @@ def is_tracing_enabled(tracer): return span.is_recording() -class OpenTelemetryAdapter(Telemetry): - def __init__(self, config: OpenTelemetryConfig): +class TelemetryAdapter(Telemetry): + def __init__(self, config: TelemetryConfig) -> None: self.config = config resource = Resource.create( @@ -57,22 +66,29 @@ class OpenTelemetryAdapter(Telemetry): provider = TracerProvider(resource=resource) trace.set_tracer_provider(provider) - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - # Set up metrics - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + if TelemetrySink.JAEGER in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> None: @@ -83,15 +99,17 @@ class OpenTelemetryAdapter(Telemetry): trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() - async def log_event(self, event: Event) -> None: + async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: if isinstance(event, UnstructuredLogEvent): - self._log_unstructured(event) + self._log_unstructured(event, ttl_seconds) elif isinstance(event, MetricEvent): self._log_metric(event) elif isinstance(event, StructuredLogEvent): - self._log_structured(event) + self._log_structured(event, ttl_seconds) + else: + raise ValueError(f"Unknown event type: {event}") - def _log_unstructured(self, event: UnstructuredLogEvent) -> None: + def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: # Use global storage instead of instance storage span_id = string_to_span_id(event.span_id) @@ -104,6 +122,7 @@ class OpenTelemetryAdapter(Telemetry): attributes={ "message": event.message, "severity": event.severity.value, + "__ttl__": ttl_seconds, **event.attributes, }, timestamp=timestamp_ns, @@ -154,11 +173,14 @@ class OpenTelemetryAdapter(Telemetry): ) return _GLOBAL_STORAGE["up_down_counters"][name] - def _log_structured(self, event: StructuredLogEvent) -> None: + def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None: with self._lock: span_id = string_to_span_id(event.span_id) trace_id = string_to_trace_id(event.trace_id) tracer = trace.get_tracer(__name__) + if event.attributes is None: + event.attributes = {} + event.attributes["__ttl__"] = ttl_seconds if isinstance(event.payload, SpanStartPayload): # Check if span already exists to prevent duplicates @@ -170,7 +192,6 @@ class OpenTelemetryAdapter(Telemetry): parent_span_id = string_to_span_id(event.payload.parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) - # Create a new trace context with the trace_id context = trace.Context(trace_id=trace_id) if parent_span: context = trace.set_span_in_context(parent_span, context) @@ -179,14 +200,9 @@ class OpenTelemetryAdapter(Telemetry): name=event.payload.name, context=context, attributes=event.attributes or {}, - start_time=int(event.timestamp.timestamp() * 1e9), ) _GLOBAL_STORAGE["active_spans"][span_id] = span - # Set as current span using context manager - with trace.use_span(span, end_on_exit=False): - pass # Let the span continue beyond this block - elif isinstance(event.payload, SpanEndPayload): span = _GLOBAL_STORAGE["active_spans"].get(span_id) if span: @@ -199,10 +215,33 @@ class OpenTelemetryAdapter(Telemetry): else trace.Status(status_code=trace.StatusCode.ERROR) ) span.set_status(status) - span.end(end_time=int(event.timestamp.timestamp() * 1e9)) - - # Remove from active spans + span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) + else: + raise ValueError(f"Unknown structured log event: {event}") - async def get_trace(self, trace_id: str) -> Trace: - raise NotImplementedError("Trace retrieval not implemented yet") + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + return await self.trace_store.query_traces( + attribute_filters=attribute_filters, + limit=limit, + offset=offset, + order_by=order_by, + ) + + async def get_span_tree( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + return await self.trace_store.get_materialized_span( + span_id=span_id, + attributes_to_return=attributes_to_return, + max_depth=max_depth, + ) diff --git a/llama_stack/providers/remote/telemetry/sample/__init__.py b/llama_stack/providers/inline/telemetry/sample/__init__.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/__init__.py rename to llama_stack/providers/inline/telemetry/sample/__init__.py diff --git a/llama_stack/providers/remote/telemetry/sample/config.py b/llama_stack/providers/inline/telemetry/sample/config.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/config.py rename to llama_stack/providers/inline/telemetry/sample/config.py diff --git a/llama_stack/providers/remote/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py similarity index 100% rename from llama_stack/providers/remote/telemetry/sample/sample.py rename to llama_stack/providers/inline/telemetry/sample/sample.py diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index b673d8110..e83a631af 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.sample.SampleConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index ac537e076..a53ad5b94 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -14,9 +14,12 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.telemetry, provider_type="inline::meta-reference", - pip_packages=[], - module="llama_stack.providers.inline.meta_reference.telemetry", - config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig", + pip_packages=[ + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-http", + ], + module="llama_stack.providers.inline.telemetry.meta_reference", + config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig", ), remote_provider_spec( api=Api.telemetry, @@ -27,18 +30,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig", ), ), - remote_provider_spec( - api=Api.telemetry, - adapter=AdapterSpec( - adapter_type="opentelemetry-jaeger", - pip_packages=[ - "opentelemetry-api", - "opentelemetry-sdk", - "opentelemetry-exporter-jaeger", - "opentelemetry-semantic-conventions", - ], - module="llama_stack.providers.remote.telemetry.opentelemetry", - config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig", - ), - ), ] diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index c2e4506bf..db52270a7 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, List, Optional from llama_stack.apis.datasetio import * # noqa: F403 @@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): ) self.dataset_infos[dataset_def.identifier] = dataset_def + async def unregister_dataset(self, dataset_id: str) -> None: + key = f"{DATASETS_PREFIX}{dataset_id}" + await self.kvstore.delete(key=key) + del self.dataset_infos[dataset_id] + async def get_rows_paginated( self, dataset_id: str, @@ -95,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): total_count=len(rows), next_page_token=str(end), ) + + async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + dataset_def = self.dataset_infos[dataset_id] + loaded_dataset = load_hf_dataset(dataset_def) + + # Convert rows to HF Dataset format + new_dataset = hf_datasets.Dataset.from_list(rows) + + # Concatenate the new rows with existing dataset + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) + + if dataset_def.metadata.get("path", None): + updated_dataset.push_to_hub(dataset_def.metadata["path"]) + else: + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + ) diff --git a/llama_stack/providers/remote/inference/cerebras/__init__.py b/llama_stack/providers/remote/inference/cerebras/__init__.py new file mode 100644 index 000000000..a24bb2c70 --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import CerebrasImplConfig + + +async def get_adapter_impl(config: CerebrasImplConfig, _deps): + from .cerebras import CerebrasInferenceAdapter + + assert isinstance( + config, CerebrasImplConfig + ), f"Unexpected config type: {type(config)}" + + impl = CerebrasInferenceAdapter(config) + + await impl.initialize() + + return impl diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py new file mode 100644 index 000000000..65022f85e --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import AsyncGenerator + +from cerebras.cloud.sdk import AsyncCerebras + +from llama_models.llama3.api.chat_format import ChatFormat + +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_stack.apis.inference import * # noqa: F403 + +from llama_models.datatypes import CoreModelId + +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, + completion_request_to_prompt, +) + +from .config import CerebrasImplConfig + + +model_aliases = [ + build_model_alias( + "llama3.1-8b", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3.1-70b", + CoreModelId.llama3_1_70b_instruct.value, + ), +] + + +class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): + def __init__(self, config: CerebrasImplConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=model_aliases, + ) + self.config = config + self.formatter = ChatFormat(Tokenizer.get_instance()) + + self.client = AsyncCerebras( + base_url=self.config.base_url, api_key=self.config.api_key + ) + + async def initialize(self) -> None: + return + + async def shutdown(self) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion( + request, + ) + else: + return await self._nonstream_completion(request) + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_completion_response(r, self.formatter) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_chat_completion_response(r, self.formatter) + + async def _stream_chat_completion( + self, request: CompletionRequest + ) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk + + def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + if request.sampling_params and request.sampling_params.top_k: + raise ValueError("`top_k` not supported by Cerebras") + + prompt = "" + if type(request) == ChatCompletionRequest: + prompt = chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ) + elif type(request) == CompletionRequest: + prompt = completion_request_to_prompt(request, self.formatter) + else: + raise ValueError(f"Unknown request type {type(request)}") + + return { + "model": request.model, + "prompt": prompt, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + } + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py new file mode 100644 index 000000000..9bae6ca4d --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Any, Dict, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + +DEFAULT_BASE_URL = "https://api.cerebras.ai" + + +@json_schema_type +class CerebrasImplConfig(BaseModel): + base_url: str = Field( + default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), + description="Base URL for the Cerebras API", + ) + api_key: Optional[str] = Field( + default=os.environ.get("CEREBRAS_API_KEY"), + description="Cerebras API Key", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "base_url": DEFAULT_BASE_URL, + "api_key": "${env.CEREBRAS_API_KEY}", + } diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 74c0b8601..f89629afc 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -180,7 +180,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) r = await self.client.generate(**params) - assert isinstance(r, dict) choice = OpenAICompatCompletionChoice( finish_reason=r["done_reason"] if r["done"] else None, diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py b/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py deleted file mode 100644 index 0842afe2d..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import OpenTelemetryConfig - - -async def get_adapter_impl(config: OpenTelemetryConfig, _deps): - from .opentelemetry import OpenTelemetryAdapter - - impl = OpenTelemetryAdapter(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py deleted file mode 100644 index 5e9dff1a1..000000000 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict - -from pydantic import BaseModel, Field - - -class OpenTelemetryConfig(BaseModel): - otel_endpoint: str = Field( - default="http://localhost:4318/v1/traces", - description="The OpenTelemetry collector endpoint URL", - ) - service_name: str = Field( - default="llama-stack", - description="The service name to use for telemetry", - ) - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}", - "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", - } diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index dd2cbd019..7d88b6115 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -81,6 +81,18 @@ class TestDatasetIO: assert len(response) == 1 assert response[0].identifier == "test_dataset" + with pytest.raises(Exception) as exc_info: + # unregister a dataset that does not exist + await datasets_impl.unregister_dataset("test_dataset2") + + await datasets_impl.unregister_dataset("test_dataset") + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + + with pytest.raises(Exception) as exc_info: + await datasets_impl.unregister_dataset("test_dataset") + @pytest.mark.asyncio async def test_get_rows_paginated(self, datasetio_stack): datasetio_impl, datasets_impl = datasetio_stack diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 65008de66..32061ff62 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( ) from llama_stack.providers.remote.inference.bedrock import BedrockConfig +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig @@ -65,6 +66,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_cerebras() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig( + api_key=get_env_or_fail("CEREBRAS_API_KEY"), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( @@ -225,6 +241,7 @@ INFERENCE_FIXTURES = [ "vllm_remote", "remote", "bedrock", + "cerebras", "nvidia", "tgi", "sambanova", diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index f0f1d0eb2..aa2f0b413 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -94,6 +94,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip("Other inference providers don't support completion() yet") @@ -139,6 +140,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip( "Other inference providers don't support structured output in completions yet" @@ -211,7 +213,15 @@ class TestInference: response = await inference_impl.chat_completion( model_id=inference_model, messages=[ - SystemMessage(content="You are a helpful assistant."), + # we include context about Michael Jordan in the prompt so that the test is + # focused on the funtionality of the model and not on the information embedded + # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons. + SystemMessage( + content=( + "You are a helpful assistant.\n\n" + "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons." + ) + ), UserMessage(content="Please give me information about Michael Jordan."), ], stream=False, diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py new file mode 100644 index 000000000..ed1343e0b --- /dev/null +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import List, Optional, Protocol + +import aiosqlite + +from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace + + +class TraceStore(Protocol): + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: ... + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: ... + + +class SQLiteTraceStore(TraceStore): + def __init__(self, conn_string: str): + self.conn_string = conn_string + + async def query_traces( + self, + attribute_filters: Optional[List[QueryCondition]] = None, + limit: Optional[int] = 100, + offset: Optional[int] = 0, + order_by: Optional[List[str]] = None, + ) -> List[Trace]: + + def build_where_clause() -> tuple[str, list]: + if not attribute_filters: + return "", [] + + ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"} + + conditions = [ + f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?" + for condition in attribute_filters + ] + params = [condition.value for condition in attribute_filters] + where_clause = " WHERE " + " AND ".join(conditions) + return where_clause, params + + def build_order_clause() -> str: + if not order_by: + return "" + + order_clauses = [] + for field in order_by: + desc = field.startswith("-") + clean_field = field[1:] if desc else field + order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}") + return " ORDER BY " + ", ".join(order_clauses) + + # Build the main query + base_query = """ + WITH matching_traces AS ( + SELECT DISTINCT t.trace_id + FROM traces t + JOIN spans s ON t.trace_id = s.trace_id + {where_clause} + ), + filtered_traces AS ( + SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time + FROM matching_traces mt + JOIN traces t ON mt.trace_id = t.trace_id + LEFT JOIN spans s ON t.trace_id = s.trace_id + {order_clause} + ) + SELECT DISTINCT trace_id, root_span_id, start_time, end_time + FROM filtered_traces + LIMIT {limit} OFFSET {offset} + """ + + where_clause, params = build_where_clause() + query = base_query.format( + where_clause=where_clause, + order_clause=build_order_clause(), + limit=limit, + offset=offset, + ) + + # Execute query and return results + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [ + Trace( + trace_id=row["trace_id"], + root_span_id=row["root_span_id"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + ) + for row in rows + ] + + async def get_materialized_span( + self, + span_id: str, + attributes_to_return: Optional[List[str]] = None, + max_depth: Optional[int] = None, + ) -> SpanWithChildren: + # Build the attributes selection + attributes_select = "s.attributes" + if attributes_to_return: + json_object = ", ".join( + f"'{key}', json_extract(s.attributes, '$.{key}')" + for key in attributes_to_return + ) + attributes_select = f"json_object({json_object})" + + # SQLite CTE query with filtered attributes + query = f""" + WITH RECURSIVE span_tree AS ( + SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes + FROM spans s + WHERE s.span_id = ? + + UNION ALL + + SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes + FROM spans s + JOIN span_tree st ON s.parent_span_id = st.span_id + WHERE (? IS NULL OR st.depth < ?) + ) + SELECT * + FROM span_tree + ORDER BY depth, start_time + """ + + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor: + rows = await cursor.fetchall() + + if not rows: + raise ValueError(f"Span {span_id} not found") + + # Build span tree + spans_by_id = {} + root_span = None + + for row in rows: + span = SpanWithChildren( + span_id=row["span_id"], + trace_id=row["trace_id"], + parent_span_id=row["parent_span_id"], + name=row["name"], + start_time=datetime.fromisoformat(row["start_time"]), + end_time=datetime.fromisoformat(row["end_time"]), + attributes=json.loads(row["filtered_attributes"]), + status=row["status"].lower(), + children=[], + ) + + spans_by_id[span.span_id] = span + + if span.span_id == span_id: + root_span = span + elif span.parent_span_id in spans_by_id: + spans_by_id[span.parent_span_id].children.append(span) + + return root_span diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index b53dc0df9..54558afdc 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -69,7 +69,7 @@ class TraceContext: self.logger = logger self.trace_id = trace_id - def push_span(self, name: str, attributes: Dict[str, Any] = None): + def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span: current_span = self.get_current_span() span = Span( span_id=generate_short_uuid(), @@ -94,6 +94,7 @@ class TraceContext: ) self.spans.append(span) + return span def pop_span(self, status: SpanStatus = SpanStatus.OK): span = self.spans.pop() @@ -203,12 +204,13 @@ class SpanContextManager: def __init__(self, name: str, attributes: Dict[str, Any] = None): self.name = name self.attributes = attributes + self.span = None def __enter__(self): global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT if context: - context.push_span(self.name, self.attributes) + self.span = context.push_span(self.name, self.attributes) return self def __exit__(self, exc_type, exc_value, traceback): @@ -217,11 +219,24 @@ class SpanContextManager: if context: context.pop_span() + def set_attribute(self, key: str, value: Any): + if self.span: + if self.span.attributes is None: + self.span.attributes = {} + self.span.attributes[key] = value + async def __aenter__(self): - return self.__enter__() + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + self.span = context.push_span(self.name, self.attributes) + return self async def __aexit__(self, exc_type, exc_value, traceback): - self.__exit__(exc_type, exc_value, traceback) + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + context.pop_span() def __call__(self, func: Callable): @wraps(func) @@ -246,3 +261,11 @@ class SpanContextManager: def span(name: str, attributes: Dict[str, Any] = None): return SpanContextManager(name, attributes) + + +def get_current_span() -> Optional[Span]: + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + return context.get_current_span() + return None diff --git a/llama_stack/templates/cerebras/__init__.py b/llama_stack/templates/cerebras/__init__.py new file mode 100644 index 000000000..9f9929b52 --- /dev/null +++ b/llama_stack/templates/cerebras/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .cerebras import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml new file mode 100644 index 000000000..a1fe93099 --- /dev/null +++ b/llama_stack/templates/cerebras/build.yaml @@ -0,0 +1,17 @@ +version: '2' +name: cerebras +distribution_spec: + description: Use Cerebras for running LLM inference + docker_image: null + providers: + inference: + - remote::cerebras + safety: + - inline::llama-guard + memory: + - inline::meta-reference + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py new file mode 100644 index 000000000..58e05adf8 --- /dev/null +++ b/llama_stack/templates/cerebras/cerebras.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig +from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases + +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::cerebras"], + "safety": ["inline::llama-guard"], + "memory": ["inline::meta-reference"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + } + + inference_provider = Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig.sample_run_config(), + ) + + core_model_to_hf_repo = { + m.descriptor(): m.huggingface_repo for m in all_registered_models() + } + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model], + provider_model_id=m.provider_model_id, + ) + for m in model_aliases + ] + + return DistributionTemplate( + name="cerebras", + distro_type="self_hosted", + description="Use Cerebras for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "CEREBRAS_API_KEY": ( + "", + "Cerebras API Key", + ), + }, + ) diff --git a/llama_stack/templates/cerebras/doc_template.md b/llama_stack/templates/cerebras/doc_template.md new file mode 100644 index 000000000..77fc6f478 --- /dev/null +++ b/llama_stack/templates/cerebras/doc_template.md @@ -0,0 +1,60 @@ +# Cerebras Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} ({{ model.provider_model_id }})` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml new file mode 100644 index 000000000..0b41f5b76 --- /dev/null +++ b/llama_stack/templates/cerebras/run.yaml @@ -0,0 +1,63 @@ +version: '2' +image_name: cerebras +docker_image: null +conda_env: cerebras +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: cerebras + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + memory: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: null + provider_model_id: llama3.1-8b +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: null + provider_model_id: llama3.1-70b +shields: +- params: null + shield_id: meta-llama/Llama-Guard-3-8B + provider_id: null + provider_shield_id: null +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] diff --git a/requirements.txt b/requirements.txt index 0ff43e246..8698495b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,8 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.56 -llama-stack-client>=0.0.56 +llama-models>=0.0.57 +llama-stack-client>=0.0.57 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index 842cbb30d..3d68021dd 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.56", + version="0.0.57", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack",