mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage
This commit is contained in:
commit
716cb09056
145 changed files with 21384 additions and 1283 deletions
80
.github/workflows/integration-tests.yml
vendored
Normal file
80
.github/workflows/integration-tests.yml
vendored
Normal file
|
@ -0,0 +1,80 @@
|
|||
name: Integration tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
ollama:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install Ollama
|
||||
run: |
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
- name: Pull Ollama image
|
||||
run: |
|
||||
ollama pull llama3.2:3b-instruct-fp16
|
||||
|
||||
- name: Start Ollama in background
|
||||
run: |
|
||||
nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &
|
||||
|
||||
- name: Set Up Environment and Install Dependencies
|
||||
run: |
|
||||
uv sync --extra dev --extra test
|
||||
uv pip install ollama faiss-cpu
|
||||
uv pip install -e .
|
||||
|
||||
- name: Wait for Ollama to start
|
||||
run: |
|
||||
echo "Waiting for Ollama..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
|
||||
echo "Ollama is running!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Ollama failed to start"
|
||||
ollama ps
|
||||
ollama.log
|
||||
exit 1
|
||||
|
||||
- name: Start Llama Stack server in background
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
# TODO: use "llama stack run"
|
||||
nohup uv run python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
echo "Waiting for Llama Stack server..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
|
||||
echo " Llama Stack server is up!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo " Llama Stack server failed to start"
|
||||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: Run Inference Integration Tests
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
run: |
|
||||
uv run pytest -v tests/integration/inference --stack-config=ollama --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
|
76
.github/workflows/providers-build.yml
vendored
Normal file
76
.github/workflows/providers-build.yml
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
name: Test Llama Stack Build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'llama_stack/cli/stack/build.py'
|
||||
- 'llama_stack/cli/stack/_build.py'
|
||||
- 'llama_stack/distribution/build.*'
|
||||
- 'llama_stack/distribution/*.sh'
|
||||
- '.github/workflows/providers-build.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'llama_stack/cli/stack/build.py'
|
||||
- 'llama_stack/cli/stack/_build.py'
|
||||
- 'llama_stack/distribution/build.*'
|
||||
- 'llama_stack/distribution/*.sh'
|
||||
- '.github/workflows/providers-build.yml'
|
||||
|
||||
jobs:
|
||||
generate-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
templates: ${{ steps.set-matrix.outputs.templates }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate Template List
|
||||
id: set-matrix
|
||||
run: |
|
||||
templates=$(ls llama_stack/templates/*/*build.yaml | awk -F'/' '{print $(NF-1)}' | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "templates=$templates" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build:
|
||||
needs: generate-matrix
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
template: ${{ fromJson(needs.generate-matrix.outputs.templates) }}
|
||||
image-type: [venv, container]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install LlamaStack
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install -e .
|
||||
|
||||
- name: Print build dependencies
|
||||
run: |
|
||||
uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test --print-deps-only
|
||||
|
||||
- name: Run Llama Stack Build
|
||||
run: |
|
||||
uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test
|
||||
|
||||
- name: Print dependencies in the image
|
||||
if: matrix.image-type == 'venv'
|
||||
run: |
|
||||
source test/bin/activate
|
||||
uv pip list
|
45
.github/workflows/stale_bot.yml
vendored
Normal file
45
.github/workflows/stale_bot.yml
vendored
Normal file
|
@ -0,0 +1,45 @@
|
|||
name: Close stale issues and PRs
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # every day at midnight
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Stale Action
|
||||
uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: >
|
||||
This issue has been automatically marked as stale because it has not had activity within 60 days.
|
||||
It will be automatically closed if no further activity occurs within 30 days.
|
||||
close-issue-message: >
|
||||
This issue has been automatically closed due to inactivity.
|
||||
Please feel free to reopen if you feel it is still relevant!
|
||||
days-before-issue-stale: 60
|
||||
days-before-issue-close: 30
|
||||
stale-pr-label: 'stale'
|
||||
stale-pr-message: >
|
||||
This pull request has been automatically marked as stale because it has not had activity within 60 days.
|
||||
It will be automatically closed if no further activity occurs within 30 days.
|
||||
close-pr-message: >
|
||||
This pull request has been automatically closed due to inactivity.
|
||||
Please feel free to reopen if you intend to continue working on it!
|
||||
days-before-pr-stale: 60
|
||||
days-before-pr-close: 30
|
||||
operations-per-run: 300
|
5
.github/workflows/unit-tests.yml
vendored
5
.github/workflows/unit-tests.yml
vendored
|
@ -1,6 +1,8 @@
|
|||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
@ -31,7 +33,7 @@ jobs:
|
|||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
uv run --python ${{ matrix.python }} --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --cov=llama_stack -s -v tests/unit/ --junitxml=pytest-report-${{ matrix.python }}.xml
|
||||
PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
|
@ -41,4 +43,5 @@ jobs:
|
|||
path: |
|
||||
.pytest_cache/
|
||||
pytest-report-${{ matrix.python }}.xml
|
||||
htmlcov-${{ matrix.python }}/
|
||||
retention-days: 7
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -22,3 +22,4 @@ pyrightconfig.json
|
|||
venv/
|
||||
pytest-report.xml
|
||||
.coverage
|
||||
.python-version
|
||||
|
|
|
@ -8,6 +8,7 @@ repos:
|
|||
rev: v5.0.0 # Latest stable version
|
||||
hooks:
|
||||
- id: check-merge-conflict
|
||||
args: ['--assume-in-merge']
|
||||
- id: trailing-whitespace
|
||||
exclude: '\.py$' # Exclude Python files as Ruff already handles them
|
||||
- id: check-added-large-files
|
||||
|
@ -82,6 +83,17 @@ repos:
|
|||
require_serial: true
|
||||
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: openapi-codegen
|
||||
name: API Spec Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.6.2
|
||||
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null 2>&1'
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
3.10
|
|
@ -61,6 +61,7 @@ outlined on that page and do not file a public issue.
|
|||
|
||||
We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments.
|
||||
You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/).
|
||||
|
||||
You can install the dependencies by running:
|
||||
|
||||
```bash
|
||||
|
@ -70,6 +71,11 @@ uv pip install -e .
|
|||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> You can pin a specific version of Python to use for `uv` by adding a `.python-version` file in the root project directory.
|
||||
> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
|
||||
> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
|
||||
|
||||
Note that you can create a dotenv file `.env` that includes necessary environment variables:
|
||||
```
|
||||
LLAMA_STACK_BASE_URL=http://localhost:8321
|
||||
|
@ -102,6 +108,22 @@ uv run pre-commit run --all-files
|
|||
> [!CAUTION]
|
||||
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
|
||||
|
||||
## Running unit tests
|
||||
|
||||
You can run the unit tests by running:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
./scripts/unit-tests.sh
|
||||
```
|
||||
|
||||
If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
|
||||
|
||||
```
|
||||
source .venv/bin/activate
|
||||
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
|
||||
```
|
||||
|
||||
## Adding a new dependency to the project
|
||||
|
||||
To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
[](https://pypi.org/project/llama-stack/)
|
||||
[](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
|
||||
[](https://discord.gg/llama-stack)
|
||||

|
||||
|
||||
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
|
||||
|
||||
|
@ -50,6 +51,10 @@ Here is a list of the various API providers and available distributions that can
|
|||
| PG Vector | Single Node | | | ✅ | | |
|
||||
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
||||
| vLLM | Hosted and Single Node | | ✅ | | | |
|
||||
| OpenAI | Hosted | | ✅ | | | |
|
||||
| Anthropic | Hosted | | ✅ | | | |
|
||||
| Gemini | Hosted | | ✅ | | | |
|
||||
|
||||
|
||||
### Distributions
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"cerebras": [
|
||||
|
@ -62,6 +63,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -97,6 +99,7 @@
|
|||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -132,6 +135,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -168,6 +172,7 @@
|
|||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -203,6 +208,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -236,6 +242,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"hf-endpoint": [
|
||||
|
@ -270,6 +277,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"hf-serverless": [
|
||||
|
@ -304,6 +312,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -344,6 +353,7 @@
|
|||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"zmq"
|
||||
],
|
||||
|
@ -385,6 +395,7 @@
|
|||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"zmq"
|
||||
],
|
||||
|
@ -417,6 +428,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"ollama": [
|
||||
|
@ -451,6 +463,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"open-benchmark": [
|
||||
|
@ -485,8 +498,44 @@
|
|||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn"
|
||||
],
|
||||
"passthrough": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
|
@ -517,6 +566,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -585,6 +635,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -620,6 +671,7 @@
|
|||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
|
@ -654,6 +706,7 @@
|
|||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
|
|
273
docs/_static/llama-stack-spec.html
vendored
273
docs/_static/llama-stack-spec.html
vendored
|
@ -6,8 +6,8 @@
|
|||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>OpenAPI specification</title>
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
<script type="module" src="https://unpkg.com/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://unpkg.com/@stoplight/elements/styles.min.css">
|
||||
<script type="module" src="https://cdn.jsdelivr.net/npm/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@stoplight/elements/styles.min.css">
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
|
@ -2151,6 +2151,48 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/providers/{provider_id}": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ProviderInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Providers"
|
||||
],
|
||||
"description": "",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "provider_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/tool-runtime/invoke": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -2642,7 +2684,7 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/inspect/providers": {
|
||||
"/v1/providers": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -4347,24 +4389,6 @@
|
|||
"type": "string",
|
||||
"description": "Unique identifier for the tool call this response is for"
|
||||
},
|
||||
"tool_name": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"brave_search",
|
||||
"wolfram_alpha",
|
||||
"photogen",
|
||||
"code_interpreter"
|
||||
],
|
||||
"title": "BuiltinTool"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"description": "Name of the tool that was called"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"description": "The response content from the tool"
|
||||
|
@ -4374,7 +4398,6 @@
|
|||
"required": [
|
||||
"role",
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"content"
|
||||
],
|
||||
"title": "ToolResponseMessage",
|
||||
|
@ -4549,7 +4572,7 @@
|
|||
"metrics": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MetricEvent"
|
||||
"$ref": "#/components/schemas/MetricInResponse"
|
||||
}
|
||||
},
|
||||
"completion_message": {
|
||||
|
@ -4571,46 +4594,9 @@
|
|||
"title": "ChatCompletionResponse",
|
||||
"description": "Response from a chat completion request."
|
||||
},
|
||||
"MetricEvent": {
|
||||
"MetricInResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"trace_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"span_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"attributes": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "metric",
|
||||
"default": "metric"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string"
|
||||
},
|
||||
|
@ -4630,15 +4616,10 @@
|
|||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"trace_id",
|
||||
"span_id",
|
||||
"timestamp",
|
||||
"type",
|
||||
"metric",
|
||||
"value",
|
||||
"unit"
|
||||
"value"
|
||||
],
|
||||
"title": "MetricEvent"
|
||||
"title": "MetricInResponse"
|
||||
},
|
||||
"TokenLogProbs": {
|
||||
"type": "object",
|
||||
|
@ -4715,6 +4696,12 @@
|
|||
"CompletionResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metrics": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MetricInResponse"
|
||||
}
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The generated completion text"
|
||||
|
@ -4924,7 +4911,7 @@
|
|||
"metrics": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MetricEvent"
|
||||
"$ref": "#/components/schemas/MetricInResponse"
|
||||
}
|
||||
},
|
||||
"event": {
|
||||
|
@ -5082,6 +5069,12 @@
|
|||
"CompletionResponseStreamChunk": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metrics": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MetricInResponse"
|
||||
}
|
||||
},
|
||||
"delta": {
|
||||
"type": "string",
|
||||
"description": "New content generated since last chunk. This can be one or more tokens."
|
||||
|
@ -7961,6 +7954,53 @@
|
|||
],
|
||||
"title": "InsertChunksRequest"
|
||||
},
|
||||
"ProviderInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_type": {
|
||||
"type": "string"
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"api",
|
||||
"provider_id",
|
||||
"provider_type",
|
||||
"config"
|
||||
],
|
||||
"title": "ProviderInfo"
|
||||
},
|
||||
"InvokeToolRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -8173,27 +8213,6 @@
|
|||
],
|
||||
"title": "ListModelsResponse"
|
||||
},
|
||||
"ProviderInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider_type": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"api",
|
||||
"provider_id",
|
||||
"provider_type"
|
||||
],
|
||||
"title": "ProviderInfo"
|
||||
},
|
||||
"ListProvidersResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -8363,6 +8382,75 @@
|
|||
],
|
||||
"title": "LogSeverity"
|
||||
},
|
||||
"MetricEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"trace_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"span_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"attributes": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "metric",
|
||||
"default": "metric"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"trace_id",
|
||||
"span_id",
|
||||
"timestamp",
|
||||
"type",
|
||||
"metric",
|
||||
"value",
|
||||
"unit"
|
||||
],
|
||||
"title": "MetricEvent"
|
||||
},
|
||||
"SpanEndPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -10125,6 +10213,10 @@
|
|||
{
|
||||
"name": "PostTraining (Coming Soon)"
|
||||
},
|
||||
{
|
||||
"name": "Providers",
|
||||
"x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
|
||||
},
|
||||
{
|
||||
"name": "Safety"
|
||||
},
|
||||
|
@ -10171,6 +10263,7 @@
|
|||
"Inspect",
|
||||
"Models",
|
||||
"PostTraining (Coming Soon)",
|
||||
"Providers",
|
||||
"Safety",
|
||||
"Scoring",
|
||||
"ScoringFunctions",
|
||||
|
|
169
docs/_static/llama-stack-spec.yaml
vendored
169
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1444,6 +1444,34 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/InsertChunksRequest'
|
||||
required: true
|
||||
/v1/providers/{provider_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProviderInfo'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Providers
|
||||
description: ''
|
||||
parameters:
|
||||
- name: provider_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/tool-runtime/invoke:
|
||||
post:
|
||||
responses:
|
||||
|
@ -1782,7 +1810,7 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/RegisterModelRequest'
|
||||
required: true
|
||||
/v1/inspect/providers:
|
||||
/v1/providers:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
|
@ -2943,17 +2971,6 @@ components:
|
|||
type: string
|
||||
description: >-
|
||||
Unique identifier for the tool call this response is for
|
||||
tool_name:
|
||||
oneOf:
|
||||
- type: string
|
||||
enum:
|
||||
- brave_search
|
||||
- wolfram_alpha
|
||||
- photogen
|
||||
- code_interpreter
|
||||
title: BuiltinTool
|
||||
- type: string
|
||||
description: Name of the tool that was called
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: The response content from the tool
|
||||
|
@ -2961,7 +2978,6 @@ components:
|
|||
required:
|
||||
- role
|
||||
- call_id
|
||||
- tool_name
|
||||
- content
|
||||
title: ToolResponseMessage
|
||||
description: >-
|
||||
|
@ -3101,7 +3117,7 @@ components:
|
|||
metrics:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/MetricEvent'
|
||||
$ref: '#/components/schemas/MetricInResponse'
|
||||
completion_message:
|
||||
$ref: '#/components/schemas/CompletionMessage'
|
||||
description: The complete response message
|
||||
|
@ -3116,29 +3132,9 @@ components:
|
|||
- completion_message
|
||||
title: ChatCompletionResponse
|
||||
description: Response from a chat completion request.
|
||||
MetricEvent:
|
||||
MetricInResponse:
|
||||
type: object
|
||||
properties:
|
||||
trace_id:
|
||||
type: string
|
||||
span_id:
|
||||
type: string
|
||||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: integer
|
||||
- type: number
|
||||
- type: boolean
|
||||
- type: 'null'
|
||||
type:
|
||||
type: string
|
||||
const: metric
|
||||
default: metric
|
||||
metric:
|
||||
type: string
|
||||
value:
|
||||
|
@ -3149,14 +3145,9 @@ components:
|
|||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- trace_id
|
||||
- span_id
|
||||
- timestamp
|
||||
- type
|
||||
- metric
|
||||
- value
|
||||
- unit
|
||||
title: MetricEvent
|
||||
title: MetricInResponse
|
||||
TokenLogProbs:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -3213,6 +3204,10 @@ components:
|
|||
CompletionResponse:
|
||||
type: object
|
||||
properties:
|
||||
metrics:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/MetricInResponse'
|
||||
content:
|
||||
type: string
|
||||
description: The generated completion text
|
||||
|
@ -3412,7 +3407,7 @@ components:
|
|||
metrics:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/MetricEvent'
|
||||
$ref: '#/components/schemas/MetricInResponse'
|
||||
event:
|
||||
$ref: '#/components/schemas/ChatCompletionResponseEvent'
|
||||
description: The event containing the new content
|
||||
|
@ -3531,6 +3526,10 @@ components:
|
|||
CompletionResponseStreamChunk:
|
||||
type: object
|
||||
properties:
|
||||
metrics:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/MetricInResponse'
|
||||
delta:
|
||||
type: string
|
||||
description: >-
|
||||
|
@ -5438,6 +5437,32 @@ components:
|
|||
- vector_db_id
|
||||
- chunks
|
||||
title: InsertChunksRequest
|
||||
ProviderInfo:
|
||||
type: object
|
||||
properties:
|
||||
api:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
provider_type:
|
||||
type: string
|
||||
config:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- api
|
||||
- provider_id
|
||||
- provider_type
|
||||
- config
|
||||
title: ProviderInfo
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5573,21 +5598,6 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: ListModelsResponse
|
||||
ProviderInfo:
|
||||
type: object
|
||||
properties:
|
||||
api:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
provider_type:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- api
|
||||
- provider_id
|
||||
- provider_type
|
||||
title: ProviderInfo
|
||||
ListProvidersResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5703,6 +5713,47 @@ components:
|
|||
- error
|
||||
- critical
|
||||
title: LogSeverity
|
||||
MetricEvent:
|
||||
type: object
|
||||
properties:
|
||||
trace_id:
|
||||
type: string
|
||||
span_id:
|
||||
type: string
|
||||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: integer
|
||||
- type: number
|
||||
- type: boolean
|
||||
- type: 'null'
|
||||
type:
|
||||
type: string
|
||||
const: metric
|
||||
default: metric
|
||||
metric:
|
||||
type: string
|
||||
value:
|
||||
oneOf:
|
||||
- type: integer
|
||||
- type: number
|
||||
unit:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- trace_id
|
||||
- span_id
|
||||
- timestamp
|
||||
- type
|
||||
- metric
|
||||
- value
|
||||
- unit
|
||||
title: MetricEvent
|
||||
SpanEndPayload:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6820,6 +6871,9 @@ tags:
|
|||
- name: Inspect
|
||||
- name: Models
|
||||
- name: PostTraining (Coming Soon)
|
||||
- name: Providers
|
||||
x-displayName: >-
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
- name: Safety
|
||||
- name: Scoring
|
||||
- name: ScoringFunctions
|
||||
|
@ -6844,6 +6898,7 @@ x-tagGroups:
|
|||
- Inspect
|
||||
- Models
|
||||
- PostTraining (Coming Soon)
|
||||
- Providers
|
||||
- Safety
|
||||
- Scoring
|
||||
- ScoringFunctions
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>OpenAPI specification</title>
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
<script type="module" src="https://unpkg.com/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://unpkg.com/@stoplight/elements/styles.min.css">
|
||||
<script type="module" src="https://cdn.jsdelivr.net/npm/@stoplight/elements/web-components.min.js"></script>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@stoplight/elements/styles.min.css">
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
|
|
|
@ -71,4 +71,4 @@ While there is a lot of flexibility to mix-and-match providers, often users will
|
|||
**Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.
|
||||
|
||||
|
||||
**On-device Distro**: Finally, you may want to run Llama Stack directly on an edge device (mobile phone or a tablet.) We provide Distros for iOS and Android (coming soon.)
|
||||
**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html)
|
||||
|
|
|
@ -8,12 +8,12 @@ Features:
|
|||
- Remote Inferencing: Perform inferencing tasks remotely with Llama models hosted on a remote connection (or serverless localhost).
|
||||
- Simple Integration: With easy-to-use APIs, a developer can quickly integrate Llama Stack in their Android app. The difference with local vs remote inferencing is also minimal.
|
||||
|
||||
Latest Release Notes: [v0.0.58](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.58)
|
||||
Latest Release Notes: [link](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release)
|
||||
|
||||
*Tagged releases are stable versions of the project. While we strive to maintain a stable main branch, it's not guaranteed to be free of bugs or issues.*
|
||||
|
||||
## Android Demo App
|
||||
Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/android-kotlin-app-latest/examples/android_app)
|
||||
Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-client-kotlin/tree/examples/android_app)
|
||||
|
||||
The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
|
||||
|
||||
|
@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
|
|||
Add the following dependency in your `build.gradle.kts` file:
|
||||
```
|
||||
dependencies {
|
||||
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.58")
|
||||
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
|
||||
}
|
||||
```
|
||||
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
|
||||
|
@ -36,13 +36,13 @@ If you plan on doing remote inferencing this is sufficient to get started.
|
|||
For local inferencing, it is required to include the ExecuTorch library into your app.
|
||||
|
||||
Include the ExecuTorch library by:
|
||||
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.58/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
|
||||
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
|
||||
2. Move the script to the top level of your Android app where the app directory resides:
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.58/doc/img/example_android_app_directory.png" style="width:300px">
|
||||
<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
|
||||
</p>
|
||||
|
||||
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate with commit: [0a12e33](https://github.com/pytorch/executorch/commit/0a12e33d22a3d44d1aa2af5f0d0673d45b962553).
|
||||
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
|
||||
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
|
||||
```
|
||||
dependencies {
|
||||
|
@ -58,12 +58,12 @@ Breaking down the demo app, this section will show the core pieces that are used
|
|||
### Setup Remote Inferencing
|
||||
Start a Llama Stack server on localhost. Here is an example of how you can do this using the firework.ai distribution:
|
||||
```
|
||||
conda create -n stack-fireworks python=3.10
|
||||
conda create -n stack-fireworks python=3.10
|
||||
conda activate stack-fireworks
|
||||
pip install llama-stack=0.0.58
|
||||
pip install --no-cache llama-stack==0.1.4
|
||||
llama stack build --template fireworks --image-type conda
|
||||
export FIREWORKS_API_KEY=<SOME_KEY>
|
||||
llama stack run /Users/<your_username>/.llama/distributions/llamastack-fireworks/fireworks-run.yaml --port=5050
|
||||
llama stack run fireworks --port 5050
|
||||
```
|
||||
|
||||
Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.
|
||||
|
@ -146,7 +146,7 @@ The purpose of this section is to share more details with users that would like
|
|||
### Prerequisite
|
||||
|
||||
You must complete the following steps:
|
||||
1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.58`)
|
||||
1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b latest-release`)
|
||||
2. Port the appropriate ExecuTorch libraries over into your Llama Stack Kotlin library environment.
|
||||
```
|
||||
cd llama-stack-client-kotlin-client-local
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# iOS SDK
|
||||
|
||||
We offer both remote and on-device use of Llama Stack in Swift via two components:
|
||||
|
||||
1. [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/)
|
||||
2. [LocalInferenceImpl](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/ios/inference)
|
||||
We offer both remote and on-device use of Llama Stack in Swift via a single SDK [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/) that contains two components:
|
||||
1. LlamaStackClient for remote
|
||||
2. Local Inference for on-device
|
||||
|
||||
```{image} ../../../_static/remote_or_local.gif
|
||||
:alt: Seamlessly switching between local, on-device inference and remote hosted inference
|
||||
|
@ -42,7 +41,7 @@ let request = Components.Schemas.CreateAgentTurnRequest(
|
|||
// ...
|
||||
```
|
||||
|
||||
Check out [iOSCalendarAssistant](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/ios_calendar_assistant) for a complete app demo.
|
||||
Check out [iOSCalendarAssistant](https://github.com/meta-llama/llama-stack-client-swift/tree/main/examples/ios_calendar_assistant) for a complete app demo.
|
||||
|
||||
## LocalInference
|
||||
|
||||
|
@ -58,7 +57,7 @@ let inference = LocalInference(queue: runnerQueue)
|
|||
let agents = LocalAgents(inference: self.inference)
|
||||
```
|
||||
|
||||
Check out [iOSCalendarAssistantWithLocalInf](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/ios_calendar_assistant) for a complete app demo.
|
||||
Check out [iOSCalendarAssistantWithLocalInf](https://github.com/meta-llama/llama-stack-client-swift/tree/main/examples/ios_calendar_assistant) for a complete app demo.
|
||||
|
||||
### Installation
|
||||
|
||||
|
@ -68,47 +67,6 @@ We're working on making LocalInference easier to set up. For now, you'll need t
|
|||
1. Install [Cmake](https://cmake.org/) for the executorch build`
|
||||
1. Drag `LocalInference.xcodeproj` into your project
|
||||
1. Add `LocalInference` as a framework in your app target
|
||||
1. Add a package dependency on https://github.com/pytorch/executorch (branch latest)
|
||||
1. Add all the kernels / backends from executorch (but not exectuorch itself!) as frameworks in your app target:
|
||||
- backend_coreml
|
||||
- backend_mps
|
||||
- backend_xnnpack
|
||||
- kernels_custom
|
||||
- kernels_optimized
|
||||
- kernels_portable
|
||||
- kernels_quantized
|
||||
1. In "Build Settings" > "Other Linker Flags" > "Any iOS Simulator SDK", add:
|
||||
```
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
|
||||
```
|
||||
|
||||
1. In "Build Settings" > "Other Linker Flags" > "Any iOS SDK", add:
|
||||
|
||||
```
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_optimized-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_custom-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libkernels_quantized-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_xnnpack-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_coreml-simulator-release.a
|
||||
-force_load
|
||||
$(BUILT_PRODUCTS_DIR)/libbackend_mps-simulator-release.a
|
||||
```
|
||||
|
||||
### Preparing a model
|
||||
|
||||
|
|
42
docs/source/distributions/self_hosted_distro/passthrough.md
Normal file
42
docs/source/distributions/self_hosted_distro/passthrough.md
Normal file
|
@ -0,0 +1,42 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# Passthrough Distribution
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 2
|
||||
:hidden:
|
||||
|
||||
self
|
||||
```
|
||||
|
||||
The `llamastack/distribution-passthrough` distribution consists of the following provider configurations.
|
||||
|
||||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::passthrough`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `PASSTHROUGH_API_KEY`: Passthrough API Key (default: ``)
|
||||
- `PASSTHROUGH_URL`: Passthrough URL (default: ``)
|
||||
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
- `llama3.1-8b-instruct `
|
||||
- `llama3.2-11b-vision-instruct `
|
|
@ -88,11 +88,19 @@ docker run -it \
|
|||
|
||||
:::{dropdown} Installing the Llama Stack client CLI and SDK
|
||||
|
||||
You can interact with the Llama Stack server using various client SDKs. We will use the Python SDK which you can install using the following command. Note that you must be using Python 3.10 or newer:
|
||||
You can interact with the Llama Stack server using various client SDKs. Note that you must be using Python 3.10 or newer. We will use the Python SDK which you can install via `conda` or `virtualenv`.
|
||||
|
||||
For `conda`:
|
||||
```bash
|
||||
yes | conda create -n stack-client python=3.10
|
||||
conda activate stack-client
|
||||
pip install llama-stack-client
|
||||
```
|
||||
|
||||
For `virtualenv`:
|
||||
```bash
|
||||
python -m venv stack-client
|
||||
source stack-client/bin/activate
|
||||
pip install llama-stack-client
|
||||
```
|
||||
|
||||
|
@ -173,6 +181,13 @@ response = client.inference.chat_completion(
|
|||
print(response.completion_message.content)
|
||||
```
|
||||
|
||||
To run the above example, put the code in a file called `inference.py`, ensure your `conda` or `virtualenv` environment is active, and run the following:
|
||||
```bash
|
||||
pip install llama_stack
|
||||
llama stack build --template ollama --image-type <conda|venv>
|
||||
python inference.py
|
||||
```
|
||||
|
||||
### 4. Your first RAG agent
|
||||
|
||||
Here is an example of a simple RAG (Retrieval Augmented Generation) chatbot agent which can answer questions about TorchTune documentation.
|
||||
|
@ -273,6 +288,13 @@ for prompt in user_prompts:
|
|||
log.print()
|
||||
```
|
||||
|
||||
To run the above example, put the code in a file called `rag.py`, ensure your `conda` or `virtualenv` environment is active, and run the following:
|
||||
```bash
|
||||
pip install llama_stack
|
||||
llama stack build --template ollama --image-type <conda|venv>
|
||||
python rag.py
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Learn more about Llama Stack [Concepts](../concepts/index.md)
|
||||
|
|
|
@ -38,9 +38,9 @@ We have a number of client-side SDKs available for different languages.
|
|||
| **Language** | **Client SDK** | **Package** |
|
||||
| :----: | :----: | :----: |
|
||||
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [](https://pypi.org/project/llama_stack_client/)
|
||||
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
|
||||
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/tree/latest-release) | [](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
|
||||
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [](https://npmjs.org/package/llama-stack-client)
|
||||
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
|
||||
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release) | [](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
|
||||
|
||||
## Supported Llama Stack Implementations
|
||||
|
||||
|
@ -61,6 +61,10 @@ A number of "adapters" are available for some popular Inference and Vector Store
|
|||
| Groq | Hosted |
|
||||
| SambaNova | Hosted |
|
||||
| PyTorch ExecuTorch | On-device iOS, Android |
|
||||
| OpenAI | Hosted |
|
||||
| Anthropic | Hosted |
|
||||
| Gemini | Hosted |
|
||||
|
||||
|
||||
**Vector IO API**
|
||||
| **Provider** | **Environments** |
|
||||
|
|
|
@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
|
|
|
@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
|
|||
|
||||
:param role: Must be "tool" to identify this as a tool response
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was called
|
||||
:param content: The response content from the tool
|
||||
"""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
|
|
|
@ -11,13 +11,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
|
@ -32,14 +25,21 @@ class HealthInfo(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
data: List[RouteInfo]
|
||||
|
||||
|
|
|
@ -4,9 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
from .providers import * # noqa: F401 F403
|
36
llama_stack/apis/providers/providers.py
Normal file
36
llama_stack/apis/providers/providers.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
|
@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
|
|||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
metric: str
|
||||
value: Union[int, float]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be inlcuded with the response
|
||||
|
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
|
|||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
metrics: Optional[List[MetricEvent]] = None
|
||||
metrics: Optional[List[MetricInResponse]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -10,7 +10,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
|||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now() > manifest.expires_on:
|
||||
if datetime.now(timezone.utc) > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
|
|
|
@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
|
|||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="llama3_1",
|
||||
help="Model Family (llama3_1, llama3_X, etc.)",
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
|
@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
|
|||
]
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
model_str = "\n".join(model_list)
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
|
@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
|
|||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
|
|
|
@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
|
|
@ -117,6 +117,14 @@ class Provider(BaseModel):
|
|||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: Dict[str, str] = Field(
|
||||
default_factory=Dict,
|
||||
description="""
|
||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
@ -176,6 +184,8 @@ a default SQLite store will be used.""",
|
|||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
|
||||
def providable_apis() -> List[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
|
|
59
llama_stack/distribution/providers.py
Normal file
59
llama_stack/distribution/providers.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
|
|||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
|
@ -247,6 +249,25 @@ def sort_providers_by_deps(
|
|||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
@ -48,7 +48,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
@ -206,12 +206,12 @@ class InferenceRouter(Inference):
|
|||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return metrics
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
|
@ -238,7 +238,6 @@ class InferenceRouter(Inference):
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
"core",
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
|
|
|
@ -25,7 +25,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
|
@ -306,34 +306,42 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
log_line = ""
|
||||
if args.yaml_config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
@ -368,6 +376,7 @@ def main():
|
|||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
|
|
|
@ -19,7 +19,7 @@ def preserve_contexts_async_generator(
|
|||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
|
||||
async def wrapper():
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
|
|
|
@ -7,13 +7,15 @@
|
|||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
from .distribution.datatypes import LoggingConfig
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
|
@ -34,6 +36,56 @@ CATEGORIES = [
|
|||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def config_to_category_levels(category: str, level: str):
|
||||
"""
|
||||
Helper function to be called either by environment parsing or yaml parsing to go from a list of categories and levels to a dictionary ready to be
|
||||
used by the logger dictConfig.
|
||||
|
||||
Parameters:
|
||||
category (str): logging category to apply the level to
|
||||
level (str): logging level to be used in the category
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
|
||||
category_levels: Dict[str, int] = {}
|
||||
level_value = logging._nameToLevel.get(str(level).upper())
|
||||
if level_value is None:
|
||||
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
|
||||
return category_levels
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
|
||||
"""
|
||||
Helper function to parse a yaml logging configuration found in the run.yaml
|
||||
|
||||
Parameters:
|
||||
yaml_config (Logging): the logger config object found in the run.yaml
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for category, level in yaml_config.category_levels.items():
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
@ -53,25 +105,7 @@ def parse_environment_config(env_config: str) -> Dict[str, int]:
|
|||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
@ -176,7 +210,9 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
|
|||
logger.setLevel(root_level)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
def get_logger(
|
||||
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
|
||||
) -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
@ -184,10 +220,14 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
config (Logging): optional yaml config to override the existing logger configuration
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
if config:
|
||||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
|
|
@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
|||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"today": datetime.now().strftime("%d %B %Y")},
|
||||
{
|
||||
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
|
|
|
@ -11,8 +11,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
@ -153,7 +153,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
|
@ -181,6 +180,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
|
@ -191,6 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
|
@ -219,8 +220,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
|
@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now().astimezone().isoformat()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
|
@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
start_time = datetime.now(timezone.utc).isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
output_message = None
|
||||
|
@ -275,7 +275,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
@ -296,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
@ -327,7 +326,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -389,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now().astimezone().isoformat()
|
||||
shield_call_start_time = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -413,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -436,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -451,30 +448,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
await self.handle_documents(session_id, documents, input_messages)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
for tool_name in self.tool_name_to_args.keys():
|
||||
if tool_name == MEMORY_QUERY_TOOL:
|
||||
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
||||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
@ -486,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now().astimezone().isoformat()
|
||||
inference_start_time = datetime.now(timezone.utc).isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
|
@ -504,7 +490,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=tool_defs,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
@ -596,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -667,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[],
|
||||
started_at=datetime.now().astimezone().isoformat(),
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
|
@ -684,14 +670,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
|
@ -700,7 +683,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
|
@ -720,13 +702,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -744,9 +726,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||
|
||||
# Determine which tools to include
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
|
@ -755,8 +744,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
|
@ -774,53 +765,38 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
available_tool_groups = ", ".join(
|
||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
|
@ -832,9 +808,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
@ -853,15 +829,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
) -> ToolInvocationResult:
|
||||
tool_name = tool_call.tool_name
|
||||
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
|
||||
if tool_name not in registered_tool_names:
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
|
||||
)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
if tool_name == BuiltinTool.brave_search:
|
||||
tool_name_str = WEB_SEARCH_TOOL
|
||||
else:
|
||||
tool_name_str = tool_name.value
|
||||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
@ -989,42 +996,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = WEB_SEARCH_TOOL
|
||||
else:
|
||||
name = name.value
|
||||
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
content: str,
|
||||
) -> Optional[Attachment]:
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -36,7 +36,7 @@ class AgentPersistence:
|
|||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to localfs storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
|
||||
) # Uses SQLite config specific to Meta Reference Eval storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
|
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
@ -168,10 +168,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
@ -40,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -12,3 +12,9 @@ from pydantic import BaseModel
|
|||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(),
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
|
@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
|
|||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now()
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now()
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
|
|
|
@ -8,7 +8,7 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,10 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
|
|||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel): ...
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -22,12 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -0,0 +1,296 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
|
@ -0,0 +1,989 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlmAsJudgeScoringConfig(BaseModel): ...
|
||||
class LlmAsJudgeScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
|
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
|
@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
|
@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleTelemetryImpl
|
||||
|
||||
impl = SampleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -76,6 +76,7 @@ class CodeExecutionRequest:
|
|||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
use_bwrap: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
|
@ -103,8 +104,6 @@ _set_seeds()\
|
|||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
@ -118,6 +117,13 @@ _set_seeds()\
|
|||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
|
||||
if req.use_bwrap:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
else:
|
||||
cmd = [sys.executable, "-c", script]
|
||||
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
|
@ -162,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
|||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d")
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -61,7 +62,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
script = kwargs["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
# Use environment variable to control bwrap usage
|
||||
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
|
||||
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
|
||||
res = self.code_executor.execute(req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"db_path": "{env.CHROMADB_PATH}"}
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
|
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
@ -39,13 +37,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.tool_groups,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.agents,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.agents.sample",
|
||||
config_class="llama_stack.providers.remote.agents.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.eval,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[],
|
||||
pip_packages=["tree_sitter"],
|
||||
module="llama_stack.providers.inline.eval.meta_reference",
|
||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -68,15 +68,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sample",
|
||||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -27,27 +27,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.safety.prompt_guard",
|
||||
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
deprecation_error="""
|
||||
Provider `inline::meta-reference` for API `safety` does not work with the latest Llama Stack.
|
||||
|
||||
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
|
||||
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
|
||||
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
|
||||
|
||||
""",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::llama-guard",
|
||||
|
@ -67,15 +46,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
|
|||
module="llama_stack.providers.inline.safety.code_scanner",
|
||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.safety.sample",
|
||||
config_class="llama_stack.providers.remote.safety.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -28,13 +26,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.telemetry.meta_reference",
|
||||
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.telemetry,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.telemetry.sample",
|
||||
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -92,16 +92,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.vector_io,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.vector_io.sample",
|
||||
config_class="llama_stack.providers.remote.vector_io.sample.SampleVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleAgentsImpl
|
||||
|
||||
impl = SampleAgentsImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleAgentsImpl(Agents):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to HF storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN:}",
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.models.llama.datatypes import Message
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleInferenceImpl
|
||||
|
||||
impl = SampleInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -14,3 +14,9 @@ class BingSearchToolConfig(BaseModel):
|
|||
|
||||
api_key: Optional[str] = None
|
||||
top_k: int = 3
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.BING_API_KEY:}",
|
||||
}
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -13,3 +13,9 @@ class WolframAlphaToolConfig(BaseModel):
|
|||
"""Configuration for WolframAlpha Tool Runtime"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -24,3 +24,9 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
timeout: Optional[int] = None
|
||||
host: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.QDRANT_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
|
||||
from .sample import SampleVectorIOImpl
|
||||
|
||||
impl = SampleVectorIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleVectorIOConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
class SampleVectorIOImpl(VectorIO):
|
||||
def __init__(self, config: SampleVectorIOConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# these are the vector dbs the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@ -13,4 +15,6 @@ class WeaviateRequestProviderData(BaseModel):
|
|||
|
||||
|
||||
class WeaviateVectorIOConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -23,6 +23,10 @@ class ColumnName(Enum):
|
|||
generated_answer = "generated_answer"
|
||||
context = "context"
|
||||
dialog = "dialog"
|
||||
function = "function"
|
||||
language = "language"
|
||||
id = "id"
|
||||
ground_truth = "ground_truth"
|
||||
|
||||
|
||||
VALID_SCHEMAS_FOR_SCORING = [
|
||||
|
@ -37,6 +41,15 @@ VALID_SCHEMAS_FOR_SCORING = [
|
|||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.context.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
VALID_SCHEMAS_FOR_EVAL = [
|
||||
|
@ -50,6 +63,15 @@ VALID_SCHEMAS_FOR_EVAL = [
|
|||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.completion_input.value: CompletionInputType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import logging
|
|||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
@ -86,7 +86,7 @@ class TraceContext:
|
|||
span_id=generate_short_uuid(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(),
|
||||
start_time=datetime.now(timezone.utc),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
|
|||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
|
|
|
@ -45,14 +45,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
@ -23,7 +23,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
@ -43,14 +44,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
@ -28,7 +28,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -47,14 +48,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
@ -27,7 +27,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -46,14 +47,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue