Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage
Commit 13c660f5a5 — 57 changed files with 10986 additions and 93 deletions
|
@@ -320,7 +320,7 @@ jobs:
      - name: "PR - Update comment"
        id: pr_update_comment
        if: github.event_name == 'pull_request_target'
-       uses: thollander/actions-comment-pull-request@65f9e5c9a1f2cd378bd74b2e057c9736982a8e74 # v3.0.1
+       uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
        with:
          filePath: test-summary.md
93
.github/workflows/test-external-providers.yml
vendored
Normal file
|
@@ -0,0 +1,93 @@
name: Test External Providers

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test-external-providers:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install -e .

      - name: Install Ollama custom provider
        run: |
          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
          uv pip install tests/external-provider/llama-stack-provider-ollama

      - name: Create provider configuration
        run: |
          mkdir -p /tmp/providers.d/remote/inference
          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1

      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
                echo "Llama Stack server is using custom Ollama provider"
                exit 0
              else
                echo "Llama Stack server is not using custom Ollama provider"
                exit 1
              fi
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: run inference tests
        run: |
          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
37
CHANGELOG.md
|
@@ -1,5 +1,42 @@
# Changelog

# v0.2.1
Published on: 2025-04-05T23:13:00Z

---

# v0.2.0
Published on: 2025-04-05T19:04:29Z

## Llama 4 Support

Checkout more at https://www.llama.com

---

# v0.1.9
Published on: 2025-03-29T00:52:23Z

### Build and Test Agents
* Agents: Entire document context with attachments
* RAG: Documentation with sqlite-vec faiss comparison
* Getting started: Fixes to getting started notebook.

### Agent Evals and Model Customization
* (**New**) Post-training: Add nemo customizer

### Better Engineering
* Moved sqlite-vec to non-blocking calls
* Don't return a payload on file delete

---

# v0.1.8
Published on: 2025-03-24T01:28:50Z
9
docs/_static/js/detect_theme.js
vendored
Normal file
|
@@ -0,0 +1,9 @@
document.addEventListener("DOMContentLoaded", function () {
  const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
  const htmlElement = document.documentElement;
  if (prefersDark) {
    htmlElement.setAttribute("data-theme", "dark");
  } else {
    htmlElement.setAttribute("data-theme", "light");
  }
});
|
@@ -112,6 +112,8 @@ html_theme_options = {
    # "style_nav_header_background": "#c3c9d4",
}

default_dark_mode = False

html_static_path = ["../_static"]
# html_logo = "../_static/llama-stack-logo.png"
# html_style = "../_static/css/my_theme.css"

@@ -119,6 +121,7 @@ html_static_path = ["../_static"]

def setup(app):
    app.add_css_file("css/my_theme.css")
    app.add_js_file("js/detect_theme.js")


def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
    url = f"https://hub.docker.com/r/llamastack/{text}"
|
@@ -7,13 +7,13 @@ In this guide, we'll use a local [Kind](https://kind.sigs.k8s.io/) cluster and a

First, create a local Kubernetes cluster via Kind:

-```bash
+```
kind create cluster --image kindest/node:v1.32.0 --name llama-stack-test
```

First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:

-```bash
+```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim

@@ -39,7 +39,7 @@ data:

Next, start the vLLM server as a Kubernetes Deployment and Service:

-```bash
+```
cat <<EOF |kubectl apply -f -
apiVersion: apps/v1
kind: Deployment

@@ -95,7 +95,7 @@ EOF

We can verify that the vLLM server has started successfully via the logs (this might take a couple of minutes to download the model):

-```bash
+```
$ kubectl logs -l app.kubernetes.io/name=vllm
...
INFO:     Started server process [1]

@@ -119,7 +119,7 @@ providers:

Once we have defined the run configuration for Llama Stack, we can build an image with that configuration and the server source code:

-```bash
+```
cat >/tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s <<EOF
FROM distribution-myenv:dev

@@ -135,7 +135,7 @@ podman build -f /tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s -t

We can then start the Llama Stack server by deploying a Kubernetes Pod and Service:

-```bash
+```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim

@@ -195,7 +195,7 @@ EOF
### Verifying the Deployment
We can check that the LlamaStack server has started:

-```bash
+```
$ kubectl logs -l app.kubernetes.io/name=llama-stack
...
INFO:     Started server process [1]

@@ -207,7 +207,7 @@ INFO:     Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit

Finally, we forward the Kubernetes service to a local port and test some inference requests against it via the Llama Stack Client:

-```bash
+```
kubectl port-forward service/llama-stack-service 5000:5000
llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
```
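For reference, a rough Python equivalent of that last request using the `llama-stack-client` SDK (illustrative only; it assumes the port-forward above is running and that the model id matches whatever the server has registered):

```python
# Illustrative sketch, not part of the guide: same chat-completion call via the
# Python client instead of the CLI. The model id is an assumption; use the one
# your run configuration registers.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # assumed model id
    messages=[{"role": "user", "content": "hello, what model are you?"}],
)
print(response.completion_message.content)
```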
|
||||
|
|
|
@ -25,7 +25,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
|
|||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
|
||||
You can use this distribution if you want to run an independent vLLM server for inference.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
|
@ -41,6 +41,83 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
|
||||
### Setting up vLLM server on AMD GPU
|
||||
|
||||
AMD provides two main vLLM container options:
|
||||
- rocm/vllm: Production-ready container
|
||||
- rocm/vllm-dev: Development container with the latest vLLM features
|
||||
|
||||
Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
|
||||
|
||||
Here is a sample script to start a ROCm vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on NVIDIA GPU
|
||||
|
||||
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
|
|
|
@@ -6,13 +6,13 @@ Llama Stack is a stateful service with REST APIs to support seamless transition
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/) to run inference on a Llama Model.


-### 1. Start Ollama
+### 1. Download a Llama model with Ollama

```bash
-ollama run llama3.2:3b --keepalive 60m
+ollama pull llama3.2:3b-instruct-fp16
```

-By default, Ollama keeps the model loaded in memory for 5 minutes which can be too short. We set the `--keepalive` flag to 60 minutes to ensure the model remains loaded for sometime.
+This will instruct the Ollama service to download the Llama 3.2 3B Instruct model, which we'll use in the rest of this guide.
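For illustration, a quick way to confirm the pull succeeded from Python (a rough sketch, assuming the default Ollama endpoint on `localhost:11434`; `ollama list` on the command line gives the same information):

```python
# Illustrative verification snippet: list locally available Ollama models.
import requests

resp = requests.get("http://localhost:11434/api/tags")  # Ollama's "list local models" endpoint
resp.raise_for_status()
names = [m["name"] for m in resp.json().get("models", [])]
print("locally available models:", names)
assert any("llama3.2:3b-instruct-fp16" in n for n in names), "model not pulled yet"
```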

```{admonition} Note
:class: tip
@@ -103,7 +103,5 @@ llama stack run together

2. Start Streamlit UI
```bash
-cd llama_stack/distribution/ui
-pip install -r requirements.txt
-streamlit run app.py
+uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```
|
|
234
docs/source/providers/external.md
Normal file
|
@@ -0,0 +1,234 @@
# External Providers

Llama Stack supports external providers that live outside of the main codebase. This allows you to:
- Create and maintain your own providers independently
- Share providers with others without contributing to the main codebase
- Keep provider-specific code separate from the core Llama Stack code

## Configuration

To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

## Directory Structure

The external providers directory should follow this structure:

```
providers.d/
  remote/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
  inline/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
```

Each YAML file in these directories defines a provider specification for that particular API.
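To make the layout concrete, here is a minimal sketch (illustrative only; the real loading is done by Llama Stack's `get_provider_registry()`) of how files in this tree map to provider types — `remote/<api>/<name>.yaml` registers as `remote::<name>` and `inline/<api>/<name>.yaml` as `inline::<name>`:

```python
# Illustrative sketch: walk providers.d and report what each spec file would
# register as. Paths and field access mirror the spec formats described below.
import glob
import os

import yaml

providers_dir = "/etc/llama-stack/providers.d"
for kind in ("remote", "inline"):
    for spec_path in glob.glob(os.path.join(providers_dir, kind, "*", "*.yaml")):
        api_name = os.path.basename(os.path.dirname(spec_path))
        provider_name = os.path.splitext(os.path.basename(spec_path))[0]
        with open(spec_path) as f:
            spec = yaml.safe_load(f)
        # remote specs keep their module under "adapter"; inline specs at top level
        module = spec.get("module") or spec.get("adapter", {}).get("module")
        print(f"{kind}::{provider_name} (api={api_name}, module={module})")
```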

## Provider Types

Llama Stack supports two types of external providers:

1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process

## Known External Providers

Here's a list of known external providers that you can use with Llama Stack:

| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |

### Remote Provider Specification

Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:

```yaml
adapter:
  adapter_type: custom_ollama
  pip_packages:
    - ollama
    - aiohttp
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```

#### Adapter Configuration

The `adapter` section defines how to load and configure the provider:

- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation

### Inline Provider Specification

Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:

```yaml
module: llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
  - faiss-cpu
  - numpy
api_dependencies:
  - inference
optional_api_dependencies:
  - vector_io
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
container_image: custom-vector-store:latest # optional
```

#### Inline Provider Fields

- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages

## Required Implementation

### Remote Providers

Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

This function must return an instance of the provider's adapter class that implements the required protocol for the API.

Example:
```python
async def get_adapter_impl(
    config: OllamaImplConfig, deps: Dict[Api, Any]
) -> OllamaInferenceAdapter:
    return OllamaInferenceAdapter(config)
```

### Inline Providers

Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

Example:
```python
async def get_provider_impl(
    config: VectorStoreConfig, deps: Dict[Api, Any]
) -> VectorStoreImpl:
    impl = VectorStoreImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
```

## Dependencies

The provider package must be installed on the system. For example:

```bash
$ uv pip show llama-stack-ollama-provider
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages
```

## Example: Custom Ollama Provider

Here's a complete example of creating and using a custom Ollama provider:

1. First, create the provider package:

```bash
mkdir -p llama-stack-provider-ollama
cd llama-stack-provider-ollama
git init
uv init
```

2. Edit `pyproject.toml`:

```toml
[project]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```

3. Create the provider specification:

```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```

4. Install the provider:

```bash
uv pip install -e .
```

5. Configure Llama Stack to use external providers:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

The provider will now be available in Llama Stack with the type `remote::custom_ollama`.

## Best Practices

1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.

2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.

3. **Dependencies**: Only include the minimum required dependencies in your provider package.

4. **Documentation**: Include clear documentation in your provider package about:
   - Installation requirements
   - Configuration options
   - Usage examples
   - Any limitations or known issues

5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
   You can refer to the [integration tests guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more information. Execute the test for the Provider type you are developing.
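As a rough illustration of such a test (hypothetical, not taken from the integration suite; it assumes a running server at `http://localhost:8321`, the `llama-stack-client` package, and a providers listing endpoint), you might simply assert that your provider type shows up once the stack is up:

```python
# Hypothetical smoke test for an external provider; adjust the base URL and
# provider type to your setup.
import pytest
from llama_stack_client import LlamaStackClient


@pytest.fixture
def client():
    return LlamaStackClient(base_url="http://localhost:8321")


def test_custom_ollama_provider_is_registered(client):
    provider_types = {p.provider_type for p in client.providers.list()}
    assert "remote::custom_ollama" in provider_types
```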

## Troubleshooting

If your external provider isn't being loaded:

1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment.
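A quick way to catch malformed spec files (item 2 above) before starting the server is to load them yourself; a rough sketch, assuming the spec formats shown earlier in this guide:

```python
# Illustrative pre-flight check for provider spec files; not a substitute for
# the server's own validation.
import glob
import os
import sys

import yaml

errors = []
for path in glob.glob("/etc/llama-stack/providers.d/*/*/*.yaml"):
    with open(path) as f:
        try:
            spec = yaml.safe_load(f)
        except yaml.YAMLError as e:
            errors.append(f"{path}: invalid YAML ({e})")
            continue
    is_remote = "remote" in path.split(os.sep)
    section = spec.get("adapter", {}) if is_remote else spec
    for key in ("module", "config_class"):
        if not section.get(key):
            errors.append(f"{path}: missing `{key}`")

print("\n".join(errors) or "all provider specs look well-formed")
sys.exit(1 if errors else 0)
```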
|
|
@@ -11,6 +11,10 @@ Providers come in two flavors:

Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.

## External Providers

Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.

## Agents
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.

@@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1

external
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb
|
|
|
@@ -312,6 +312,11 @@ a default SQLite store will be used.""",
        description="Configuration for the HTTP(S) server",
    )

    external_providers_dir: Optional[str] = Field(
        default=None,
        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
    )


class BuildConfig(BaseModel):
    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
|
|
@ -4,12 +4,25 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
from typing import Dict, List
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
|
@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
|
|||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
ret = {}
|
||||
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
|
||||
adapter = AdapterSpec(**spec_data["adapter"])
|
||||
spec = remote_provider_spec(
|
||||
api=api,
|
||||
adapter=adapter,
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"inline::{provider_name}",
|
||||
pip_packages=spec_data.get("pip_packages", []),
|
||||
module=spec_data["module"],
|
||||
config_class=spec_data["config_class"],
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
||||
provider_data_validator=spec_data.get("provider_data_validator"),
|
||||
container_image=spec_data.get("container_image"),
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
External providers are loaded from a directory structure like:
|
||||
|
||||
providers.d/
|
||||
remote/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
inline/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
Args:
|
||||
config: Optional StackRunConfig containing the external providers directory path
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the external providers directory doesn't exist
|
||||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
if config and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(config.external_providers_dir)
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in ret[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
return ret
|
||||
|
|
|
@@ -351,6 +351,7 @@ async def instantiate_provider(
    if not hasattr(provider_spec, "module"):
        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
    module = importlib.import_module(provider_spec.module)
    args = []
    if isinstance(provider_spec, RemoteProviderSpec):
|
|
|
@@ -608,8 +608,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
-       tools = (await self.list_tools(toolgroup_id)).data
-       for tool in tools:
+       tools = await self.list_tools(toolgroup_id)
+       for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)
|
||||
|
|
|
@@ -218,7 +218,7 @@ async def construct_stack(
    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-   impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+   impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
    await register_resources(run_config, impls)
    return impls
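Passing the run config through here is what lets external provider specs be merged before resolution. A rough sketch of inspecting the merged registry directly (illustrative only; it assumes a `run.yaml` whose environment placeholders are already resolved and that `get_provider_registry` is importable from `llama_stack.distribution.distribution`):

```python
# Illustrative only: print every provider type the stack would consider,
# including any external ones picked up from external_providers_dir.
import yaml

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry

with open("run.yaml") as f:
    run_config = StackRunConfig(**yaml.safe_load(f))

registry = get_provider_registry(run_config)
for api, providers in registry.items():
    for provider_type in sorted(providers):
        print(f"{api.name}: {provider_type}")
```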
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.9-slim
|
||||
FROM python:3.12-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
|
|
|
@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
|
|||
3. Start Streamlit UI
|
||||
|
||||
```bash
|
||||
cd llama_stack/distribution/ui
|
||||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
|
|
@@ -24,6 +24,7 @@ def main():
    # Playground pages
    chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
    rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
    tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)

    # Distribution pages
    resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)

@@ -39,6 +40,7 @@ def main():
            "Playground": [
                chat_page,
                rag_page,
                tool_page,
                application_evaluation_page,
                native_evaluation_page,
            ],
|
|
|
@@ -19,6 +19,7 @@ class LlamaStackApi:
                "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
                "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
                "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
                "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
            },
        )
|
||||
|
|
116
llama_stack/distribution/ui/page/playground/tools.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
client = llama_stack_api.client
|
||||
models = client.models.list()
|
||||
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
|
||||
|
||||
tool_groups = client.toolgroups.list()
|
||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
with st.sidebar:
|
||||
st.subheader("Model")
|
||||
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
|
||||
|
||||
st.subheader("Builtin Tools")
|
||||
toolgroup_selection = st.pills(
|
||||
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
st.subheader("MCP Servers")
|
||||
mcp_selection = st.pills(
|
||||
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
||||
active_tool_list = []
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
active_tool_list.extend(
|
||||
[
|
||||
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
|
||||
for t in client.tools.list(toolgroup_id=toolgroup_id)
|
||||
]
|
||||
)
|
||||
|
||||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.json(active_tool_list)
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
},
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
||||
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
|
||||
if prompt := st.chat_input(placeholder=""):
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
if response.event.payload.event_type == "step_progress":
|
||||
if hasattr(response.event.payload.delta, "text"):
|
||||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
yield " 🛠 "
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
tool_chat_page()
|
|
@@ -1,4 +1,5 @@
streamlit
pandas
-llama-stack-client>=0.0.55
+llama-stack-client>=0.2.1
streamlit-option-menu
llama-stack>=0.2.1
|
|
|
@@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
                    context_var.set(initial_context_values[context_var.name])

                item = await gen.__anext__()

                # Update our tracked values with any changes made during this iteration
                for context_var in context_vars:
                    initial_context_values[context_var.name] = context_var.get()

                yield item

            except StopAsyncIteration:
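The practical effect of this hunk is that context-variable writes made while producing an item (for example, a trace or request id set mid-stream) are carried forward to later iterations instead of being reset. A rough usage sketch (illustrative names; it assumes the helper is importable from `llama_stack.distribution.utils.context` with the signature implied above):

```python
# Illustrative usage sketch for preserve_contexts_async_generator.
import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

request_id: ContextVar[str] = ContextVar("request_id", default="unset")


async def stream_chunks():
    request_id.set("req-42")  # value set while producing the stream
    for i in range(3):
        yield f"chunk-{i}"


async def main():
    wrapped = preserve_contexts_async_generator(stream_chunks(), [request_id])
    async for chunk in wrapped:
        # the wrapper re-applies and re-captures request_id around each item
        print(chunk, request_id.get())


asyncio.run(main())
```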
|
|
|
@ -119,17 +119,16 @@ class Llama3:
|
|||
torch.set_default_device(device)
|
||||
else:
|
||||
print(f"Setting default device to {device}")
|
||||
torch.set_default_device(device)
|
||||
if device.type == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
|
||||
elif device.type == "xpu":
|
||||
if torch.xpu.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
|
||||
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
|
|
|
@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
|
|||
attention_chunk_size: Optional[int] = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
rope_scaling_factor: Optional[float] = None
|
||||
rope_high_freq_factor: Optional[float] = None
|
||||
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
|
@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
|
|||
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
|
||||
|
||||
if self.use_scaled_rope:
|
||||
# NOTE: ideally these values should have come from params.json. However, we have
|
||||
# shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
|
||||
# specific values.
|
||||
if self.rope_scaling_factor is None:
|
||||
self.rope_scaling_factor = 16
|
||||
if self.rope_high_freq_factor is None:
|
||||
self.rope_high_freq_factor = 1
|
||||
|
||||
return self
|
||||
|
|
|
@ -23,37 +23,25 @@ from .ffn import FeedForward
|
|||
from .moe import MoE
|
||||
|
||||
|
||||
def rmsnorm(x, eps):
|
||||
def _norm(y):
|
||||
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
return _norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
return rmsnorm(x, self.eps) * self.weight
|
||||
|
||||
|
||||
class L2Norm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor):
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
|
@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
|
|||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
def precompute_freqs_cis(
|
||||
dim: int,
|
||||
end: int,
|
||||
theta: float,
|
||||
use_scaled: bool,
|
||||
scale_factor: float,
|
||||
high_freq_factor: float,
|
||||
):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
@ -174,9 +169,7 @@ class Attention(nn.Module):
|
|||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
||||
self.norm_eps = args.norm_eps
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
|
@ -220,8 +213,8 @@ class Attention(nn.Module):
|
|||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = self.qk_norm(xq)
|
||||
xk = self.qk_norm(xk)
|
||||
xq = rmsnorm(xq, self.norm_eps)
|
||||
xk = rmsnorm(xk, self.norm_eps)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
|
@ -362,6 +355,8 @@ class Transformer(nn.Module):
|
|||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
args.rope_scaling_factor,
|
||||
args.rope_high_freq_factor,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
|
|
|
@ -91,7 +91,7 @@ def convert_to_quantized_model(
|
|||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
return quantize_int4(weight, output_device=torch.device("cuda"))
|
||||
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
|
|
|
@ -56,9 +56,11 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
|
|||
"<|text_post_train_reserved_special_token_3|>",
|
||||
"<|text_post_train_reserved_special_token_4|>",
|
||||
"<|text_post_train_reserved_special_token_5|>",
|
||||
"<|text_post_train_reserved_special_token_6|>",
|
||||
"<|text_post_train_reserved_special_token_7|>",
|
||||
"<|finetune_right_pad|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"text_post_train", 61, 6
|
||||
"text_post_train", 61, 8
|
||||
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
|
||||
|
||||
# 200080, ..., 201133
|
||||
|
|
|
@ -65,7 +65,7 @@ class Int4Weights(
|
|||
Int4ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Int4Weights",
|
||||
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
|
||||
["weight", "scale", "zero_point", "shape"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
@ -184,20 +184,13 @@ def quantize_fp8(
|
|||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
if w.ndim >= 3:
|
||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||
|
@ -212,7 +205,6 @@ def quantize_int4(
|
|||
scale=scale.to(output_device),
|
||||
zero_point=zero_point.to(output_device),
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
|
@ -247,26 +239,18 @@ def load_int4(
|
|||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input INT4.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Int4Weights(
|
||||
weight=w.to(torch.int8).to(device=output_device),
|
||||
scale=scale.to(device=output_device),
|
||||
zero_point=zero_point.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -89,7 +89,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
agent_id: str,
|
||||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
|
@ -99,7 +98,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
|
@ -64,7 +63,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -107,7 +105,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
tempdir=self.tempdir,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
|
|
|
@ -259,7 +259,7 @@ class Llama3Generator:
|
|||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
model_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -284,7 +284,7 @@ class Llama3Generator:
|
|||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
|
|
@@ -307,9 +307,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
        if model.model_type == ModelType.embedding:
            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)
-           response = await self.client.list()
-       else:
-           response = await self.client.ps()
+       # we use list() here instead of ps() -
+       # - ps() only lists running models, not available models
+       # - models not currently running are run by the ollama server as needed
+       response = await self.client.list()
        available_models = [m["model"] for m in response["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
|
|
|
@ -13,7 +13,7 @@ The `llamastack/distribution-{{ name }}` distribution consists of the following
|
|||
|
||||
{{ providers_table }}
|
||||
|
||||
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
|
||||
You can use this distribution if you want to run an independent vLLM server for inference.
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
@ -28,6 +28,83 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
|
||||
### Setting up vLLM server on AMD GPU
|
||||
|
||||
AMD provides two main vLLM container options:
|
||||
- rocm/vllm: Production-ready container
|
||||
- rocm/vllm-dev: Development container with the latest vLLM features
|
||||
|
||||
Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
|
||||
|
||||
Here is a sample script to start a ROCm vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on NVIDIA GPU
|
||||
|
||||
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
|
|
|
@@ -89,6 +89,12 @@ docs = [
    "tomli",
]
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
ui = [
    "streamlit",
    "pandas",
    "llama-stack-client>=0.2.1",
    "streamlit-option-menu",
]

[project.urls]
Homepage = "https://github.com/meta-llama/llama-stack"
|
|
|
@@ -0,0 +1,3 @@
# Ollama external provider for Llama Stack

Template code to create a new external provider for Llama Stack.

@@ -0,0 +1,7 @@
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
|
@ -0,0 +1,44 @@
|
|||
[project]
|
||||
dependencies = [
|
||||
"llama-stack",
|
||||
"pydantic",
|
||||
"ollama",
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
name = "llama-stack-provider-ollama"
|
||||
version = "0.1.0"
|
||||
description = "External provider for Ollama using the Llama Stack API"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
135
tests/external-provider/llama-stack-provider-ollama/run.yaml
Normal file
|
@ -0,0 +1,135 @@
|
|||
version: '2'
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: custom_ollama
|
||||
provider_type: remote::custom_ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: custom_ollama
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: custom_ollama
|
||||
provider_model_id: all-minilm:latest
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
||||
external_providers_dir: /tmp/providers.d
|
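Entries of the form `${env.NAME:default}` in this file are substituted from the environment when the stack starts, with the text after the colon as the fallback; `${env.INFERENCE_MODEL}` has no fallback and must be set. A sketch of the variables a local run would export (values are placeholders):

```bash
# Placeholders only; match these to your local Ollama setup before starting the stack server.
export OLLAMA_URL=http://localhost:11434
export SQLITE_STORE_DIR=~/.llama/distributions/ollama
export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct  # whichever model the custom_ollama provider serves
```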
124 tests/integration/tool_runtime/test_registration.py Normal file
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import socket
import threading
import time

import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client.types.shared_params.url import URL
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route


@pytest.fixture(scope="module")
def mcp_server():
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
        async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
            response = await client.get(url)
            response.raise_for_status()
            return [types.TextContent(type="text", text=response.text)]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )

    app = Starlette(
        debug=True,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    def run_server():
        uvicorn.run(app, host="0.0.0.0", port=port)

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Polling until the server is ready
    timeout = 10
    start_time = time.time()

    while time.time() - start_time < timeout:
        try:
            response = httpx.get(f"http://localhost:{port}/sse")
            if response.status_code == 200:
                break
        except (httpx.RequestError, httpx.HTTPStatusError):
            pass
        time.sleep(0.1)

    yield port


def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
    """
    Integration test for registering and unregistering a toolgroup using the ToolGroups API.
    """
    port = mcp_server
    test_toolgroup_id = "remote::web-fetch"
    provider_id = "model-context-protocol"

    # Cleanup before running the test
    toolgroups = llama_stack_client.toolgroups.list()
    for toolgroup in toolgroups:
        if toolgroup.identifier == test_toolgroup_id:
            llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)

    # Register the toolgroup
    llama_stack_client.toolgroups.register(
        toolgroup_id=test_toolgroup_id,
        provider_id=provider_id,
        mcp_endpoint=URL(uri=f"http://localhost:{port}/sse"),
    )

    # Verify registration
    registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
    assert registered_toolgroup is not None
    assert registered_toolgroup.identifier == test_toolgroup_id
    assert registered_toolgroup.provider_id == provider_id

    # Verify tools listing
    tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
    assert isinstance(tools_list_response, list)
    assert tools_list_response

    # Unregister the toolgroup
    llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)

    # Verify it is unregistered
    with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
        llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)

    # Verify tools are also unregistered
    unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
    assert isinstance(unregister_tools_list_response, list)
    assert not unregister_tools_list_response
223 tests/unit/distribution/test_distribution.py Normal file
@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
foo: str = Field(
|
||||
default="bar",
|
||||
description="foo",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"foo": "baz",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
with patch("llama_stack.providers.registry.inference.available_providers") as mock:
|
||||
mock.return_value = [
|
||||
ProviderSpec(
|
||||
provider_type="test_provider",
|
||||
api=Api.inference,
|
||||
adapter_type="test_adapter",
|
||||
config_class="test_provider.config.TestProviderConfig",
|
||||
)
|
||||
]
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="sample_provider",
|
||||
provider_type="sample",
|
||||
config=SampleConfig.sample_run_config(),
|
||||
)
|
||||
]
|
||||
},
|
||||
external_providers_dir=str(tmp_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_spec_yaml():
|
||||
"""Common provider spec YAML for testing."""
|
||||
return """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
module: test_provider
|
||||
api_dependencies:
|
||||
- safety
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inline_provider_spec_yaml():
|
||||
"""Common inline provider spec YAML for testing."""
|
||||
return """
|
||||
module: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
pip_packages:
|
||||
- test-package
|
||||
api_dependencies:
|
||||
- safety
|
||||
optional_api_dependencies:
|
||||
- vector_io
|
||||
provider_data_validator: test_provider.validator.TestValidator
|
||||
container_image: test-image:latest
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_directories(tmp_path):
|
||||
"""Create the API directory structure for testing."""
|
||||
# Create remote provider directory
|
||||
remote_inference_dir = tmp_path / "remote" / "inference"
|
||||
remote_inference_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create inline provider directory
|
||||
inline_inference_dir = tmp_path / "inline" / "inference"
|
||||
inline_inference_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return remote_inference_dir, inline_inference_dir
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Test suite for provider registry functionality."""
|
||||
|
||||
def test_builtin_providers(self, mock_providers):
|
||||
"""Test loading built-in providers."""
|
||||
registry = get_provider_registry(None)
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "test_provider" in registry[Api.inference]
|
||||
assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
|
||||
assert registry[Api.inference]["test_provider"].api == Api.inference
|
||||
|
||||
def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
|
||||
"""Test loading external remote providers from YAML files."""
|
||||
remote_dir, _ = api_directories
|
||||
with open(remote_dir / "test_provider.yaml", "w") as f:
|
||||
f.write(provider_spec_yaml)
|
||||
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 2
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "remote::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["remote::test_provider"]
|
||||
assert provider.adapter.adapter_type == "test_provider"
|
||||
assert provider.adapter.module == "test_provider"
|
||||
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert Api.safety in provider.api_dependencies
|
||||
|
||||
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
|
||||
"""Test loading external inline providers from YAML files."""
|
||||
_, inline_dir = api_directories
|
||||
with open(inline_dir / "test_provider.yaml", "w") as f:
|
||||
f.write(inline_provider_spec_yaml)
|
||||
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 2
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "inline::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["inline::test_provider"]
|
||||
assert provider.provider_type == "inline::test_provider"
|
||||
assert provider.module == "test_provider"
|
||||
assert provider.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert provider.pip_packages == ["test-package"]
|
||||
assert Api.safety in provider.api_dependencies
|
||||
assert Api.vector_io in provider.optional_api_dependencies
|
||||
assert provider.provider_data_validator == "test_provider.validator.TestValidator"
|
||||
assert provider.container_image == "test-image:latest"
|
||||
|
||||
def test_invalid_yaml(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of invalid YAML files."""
|
||||
remote_dir, inline_dir = api_directories
|
||||
with open(remote_dir / "invalid.yaml", "w") as f:
|
||||
f.write("invalid: yaml: content: -")
|
||||
with open(inline_dir / "invalid.yaml", "w") as f:
|
||||
f.write("invalid: yaml: content: -")
|
||||
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
get_provider_registry(base_config)
|
||||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="sample_provider",
|
||||
provider_type="sample",
|
||||
config=SampleConfig.sample_run_config(),
|
||||
)
|
||||
]
|
||||
},
|
||||
external_providers_dir="/nonexistent/dir",
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_provider_registry(config)
|
||||
|
||||
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of empty API directory."""
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 1 # Only built-in provider
|
||||
|
||||
def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of malformed remote provider spec (missing required fields)."""
|
||||
remote_dir, _ = api_directories
|
||||
malformed_spec = """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
# Missing required fields
|
||||
api_dependencies:
|
||||
- safety
|
||||
"""
|
||||
with open(remote_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
get_provider_registry(base_config)
|
||||
|
||||
def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of malformed inline provider spec (missing required fields)."""
|
||||
_, inline_dir = api_directories
|
||||
malformed_spec = """
|
||||
module: test_provider
|
||||
# Missing required config_class
|
||||
pip_packages:
|
||||
- test-package
|
||||
"""
|
||||
with open(inline_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
get_provider_registry(base_config)
|
||||
assert "config_class" in str(exc_info.value)
|
65 tests/verifications/README.md Normal file
@@ -0,0 +1,65 @@
# Llama Stack Verifications

Llama Stack Verifications provide standardized test suites to ensure API compatibility and behavior consistency across different LLM providers. These tests help verify that different models and providers implement the expected interfaces and behaviors correctly.

## Overview

This framework allows you to run the same set of verification tests against different LLM providers' OpenAI-compatible endpoints (Fireworks, Together, Groq, Cerebras, etc., and OpenAI itself) to ensure they meet the expected behavior and interface standards.
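In practice this means every provider is exercised through the standard OpenAI Python client pointed at that provider's endpoint. A minimal sketch, using the Together endpoint and model id listed in the provider fixtures later in this PR (the prompt comes from the bundled test cases):

```python
# Sketch only: the suite builds this client via pytest fixtures rather than inline.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api.together.xyz/v1",  # provider's OpenAI-compatible endpoint
    api_key=os.environ["TOGETHER_API_KEY"],
)
response = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
)
print(response.choices[0].message.content)
```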
## Features

The verification suite currently tests:

- Basic chat completions (streaming and non-streaming)
- Image input capabilities
- Structured JSON output formatting
- Tool calling functionality

## Running Tests

To run the verification tests, use pytest with the following parameters:

```bash
cd llama-stack
pytest tests/verifications/openai --provider=<provider-name>
```

Example:
```bash
# Run all tests
pytest tests/verifications/openai --provider=together

# Only run tests with Llama 4 models
pytest tests/verifications/openai --provider=together -k 'Llama-4'
```

### Parameters

- `--provider`: The provider name (openai, fireworks, together, groq, cerebras, etc.)
- `--base-url`: The base URL for the provider's API (optional - defaults to the standard URL for the specified provider)
- `--api-key`: Your API key for the provider (optional - defaults to the standard API_KEY name for the specified provider)
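
For example, the defaults can be overridden explicitly (the URL shown is the Fireworks default from the provider fixtures; the key is read from your environment):

```bash
pytest tests/verifications/openai --provider=fireworks \
  --base-url=https://api.fireworks.ai/inference/v1 \
  --api-key="$FIREWORKS_API_KEY"
```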
## Supported Providers

The verification suite currently supports:
- OpenAI
- Fireworks
- Together
- Groq
- Cerebras

## Adding New Test Cases

To add new test cases, create appropriate JSON files in the `openai/fixtures/test_cases/` directory following the existing patterns.

## Structure

- `__init__.py` - Marks the directory as a Python package
- `conftest.py` - Global pytest configuration and fixtures
- `openai/` - Tests specific to OpenAI-compatible APIs
  - `fixtures/` - Test fixtures and utilities
    - `fixtures.py` - Provider-specific fixtures
    - `load.py` - Utilities for loading test cases
    - `test_cases/` - JSON test case definitions
  - `test_chat_completion.py` - Tests for chat completion APIs
88 tests/verifications/REPORT.md Normal file
@@ -0,0 +1,88 @@
# Test Results Report

*Generated on: 2025-04-08 21:14:02*

*This report was generated by running `python tests/verifications/generate_report.py`*

## Legend

- ✅ - Test passed
- ❌ - Test failed
- ⚪ - Test not applicable or not run for this model

## Summary

| Provider | Pass Rate | Tests Passed | Total Tests |
| --- | --- | --- | --- |
| Together | 67.7% | 21 | 31 |
| Fireworks | 90.3% | 28 | 31 |
| Openai | 100.0% | 22 | 22 |

## Together

*Tests run on: 2025-04-08 16:19:59*

```bash
pytest tests/verifications/openai/test_chat_completion.py --provider=together -v
```

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-17B-128E-Instruct | Llama-4-Scout-17B-16E-Instruct |
| --- | --- | --- | --- |
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image (case 0) | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (case 0) | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (case 1) | ✅ | ❌ | ❌ |
| test_chat_streaming_image (case 0) | ⚪ | ❌ | ❌ |
| test_chat_streaming_structured_output (case 0) | ✅ | ❌ | ❌ |
| test_chat_streaming_structured_output (case 1) | ✅ | ❌ | ❌ |

## Fireworks

*Tests run on: 2025-04-08 16:18:28*

```bash
pytest tests/verifications/openai/test_chat_completion.py --provider=fireworks -v
```

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-17B-128E-Instruct | Llama-4-Scout-17B-16E-Instruct |
| --- | --- | --- | --- |
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image (case 0) | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (case 0) | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (case 1) | ✅ | ✅ | ✅ |
| test_chat_streaming_image (case 0) | ⚪ | ✅ | ✅ |
| test_chat_streaming_structured_output (case 0) | ✅ | ✅ | ✅ |
| test_chat_streaming_structured_output (case 1) | ❌ | ✅ | ✅ |

## Openai

*Tests run on: 2025-04-08 16:22:02*

```bash
pytest tests/verifications/openai/test_chat_completion.py --provider=openai -v
```

| Test | gpt-4o | gpt-4o-mini |
| --- | --- | --- |
| test_chat_non_streaming_basic (case 0) | ✅ | ✅ |
| test_chat_non_streaming_basic (case 1) | ✅ | ✅ |
| test_chat_non_streaming_image (case 0) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 0) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (case 1) | ✅ | ✅ |
| test_chat_non_streaming_tool_calling (case 0) | ✅ | ✅ |
| test_chat_streaming_basic (case 0) | ✅ | ✅ |
| test_chat_streaming_basic (case 1) | ✅ | ✅ |
| test_chat_streaming_image (case 0) | ✅ | ✅ |
| test_chat_streaming_structured_output (case 0) | ✅ | ✅ |
| test_chat_streaming_structured_output (case 1) | ✅ | ✅ |
5 tests/verifications/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
28 tests/verifications/conftest.py Normal file
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


def pytest_addoption(parser):
    parser.addoption(
        "--base-url",
        action="store",
        help="Base URL for OpenAI compatible API",
    )
    parser.addoption(
        "--api-key",
        action="store",
        help="API key",
    )
    parser.addoption(
        "--provider",
        action="store",
        help="Provider to use for testing",
    )


pytest_plugins = [
    "tests.verifications.openai.fixtures.fixtures",
]
485 tests/verifications/generate_report.py Executable file
@@ -0,0 +1,485 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test Report Generator
|
||||
|
||||
Requirements:
|
||||
pip install pytest-json-report
|
||||
|
||||
Usage:
|
||||
# Generate a report using existing test results
|
||||
python tests/verifications/generate_report.py
|
||||
|
||||
# Run tests and generate a report
|
||||
python tests/verifications/generate_report.py --run-tests
|
||||
|
||||
# Run tests for specific providers
|
||||
python tests/verifications/generate_report.py --run-tests --providers fireworks openai
|
||||
|
||||
# Save the report to a custom location
|
||||
python tests/verifications/generate_report.py --output custom_report.md
|
||||
|
||||
# Clean up old test result files
|
||||
python tests/verifications/generate_report.py --cleanup
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# Define the root directory for test results
|
||||
RESULTS_DIR = Path(__file__).parent / "test_results"
|
||||
RESULTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Maximum number of test result files to keep per provider
|
||||
MAX_RESULTS_PER_PROVIDER = 1
|
||||
|
||||
# Custom order of providers
|
||||
PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
|
||||
|
||||
# Dictionary to store providers and their models (will be populated dynamically)
|
||||
PROVIDERS = defaultdict(set)
|
||||
|
||||
# Tests will be dynamically extracted from results
|
||||
ALL_TESTS = set()
|
||||
|
||||
|
||||
def run_tests(provider):
|
||||
"""Run pytest for a specific provider and save results"""
|
||||
print(f"Running tests for provider: {provider}")
|
||||
|
||||
timestamp = int(time.time())
|
||||
result_file = RESULTS_DIR / f"{provider}_{timestamp}.json"
|
||||
temp_json_file = RESULTS_DIR / f"temp_{provider}_{timestamp}.json"
|
||||
|
||||
# Run pytest with JSON output
|
||||
cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
"tests/verifications/openai/test_chat_completion.py",
|
||||
f"--provider={provider}",
|
||||
"-v",
|
||||
"--json-report",
|
||||
f"--json-report-file={temp_json_file}",
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
print(f"Pytest exit code: {result.returncode}")
|
||||
|
||||
# Check if the JSON file was created
|
||||
if temp_json_file.exists():
|
||||
# Read the JSON file and save it to our results format
|
||||
with open(temp_json_file, "r") as f:
|
||||
test_results = json.load(f)
|
||||
|
||||
# Save results to our own format with a trailing newline
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(test_results, f, indent=2)
|
||||
f.write("\n") # Add a trailing newline for precommit
|
||||
|
||||
# Clean up temp file
|
||||
temp_json_file.unlink()
|
||||
|
||||
print(f"Test results saved to {result_file}")
|
||||
return result_file
|
||||
else:
|
||||
print(f"Error: JSON report file not created for {provider}")
|
||||
print(f"Command stdout: {result.stdout}")
|
||||
print(f"Command stderr: {result.stderr}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error running tests for {provider}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_results(result_file):
|
||||
"""Parse the test results file and extract pass/fail by model and test"""
|
||||
if not os.path.exists(result_file):
|
||||
print(f"Results file does not exist: {result_file}")
|
||||
return {}
|
||||
|
||||
with open(result_file, "r") as f:
|
||||
results = json.load(f)
|
||||
|
||||
# Initialize results dictionary
|
||||
parsed_results = defaultdict(lambda: defaultdict(dict))
|
||||
provider = os.path.basename(result_file).split("_")[0]
|
||||
|
||||
# Debug: Print summary of test results
|
||||
print(f"Test results summary for {provider}:")
|
||||
print(f"Total tests: {results.get('summary', {}).get('total', 0)}")
|
||||
print(f"Passed: {results.get('summary', {}).get('passed', 0)}")
|
||||
print(f"Failed: {results.get('summary', {}).get('failed', 0)}")
|
||||
print(f"Error: {results.get('summary', {}).get('error', 0)}")
|
||||
print(f"Skipped: {results.get('summary', {}).get('skipped', 0)}")
|
||||
|
||||
# Extract test results
|
||||
if "tests" not in results or not results["tests"]:
|
||||
print(f"No test results found in {result_file}")
|
||||
return parsed_results
|
||||
|
||||
# Map for normalizing model names
|
||||
model_name_map = {
|
||||
"Llama-3.3-8B-Instruct": "Llama-3.3-8B-Instruct",
|
||||
"Llama-3.3-70B-Instruct": "Llama-3.3-70B-Instruct",
|
||||
"Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama-4-Scout-17B-16E": "Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E": "Llama-4-Maverick-17B-128E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "Llama-4-Maverick-17B-128E-Instruct",
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-mini": "gpt-4o-mini",
|
||||
}
|
||||
|
||||
# Keep track of all models found for this provider
|
||||
provider_models = set()
|
||||
|
||||
# Track all unique test cases for each base test
|
||||
test_case_counts = defaultdict(int)
|
||||
|
||||
# First pass: count the number of cases for each test
|
||||
for test in results["tests"]:
|
||||
test_id = test.get("nodeid", "")
|
||||
|
||||
if "call" in test:
|
||||
test_name = test_id.split("::")[1].split("[")[0]
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
if input_output_match:
|
||||
test_case_counts[test_name] += 1
|
||||
|
||||
# Second pass: process the tests with case numbers only for tests with multiple cases
|
||||
for test in results["tests"]:
|
||||
test_id = test.get("nodeid", "")
|
||||
outcome = test.get("outcome", "")
|
||||
|
||||
# Only process tests that have been executed (not setup errors)
|
||||
if "call" in test:
|
||||
# Regular test that actually ran
|
||||
test_name = test_id.split("::")[1].split("[")[0]
|
||||
|
||||
# Extract input_output parameter to differentiate between test cases
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
input_output_index = input_output_match.group(1) if input_output_match else ""
|
||||
|
||||
# Create a more detailed test name with case number only if there are multiple cases
|
||||
detailed_test_name = test_name
|
||||
if input_output_index and test_case_counts[test_name] > 1:
|
||||
detailed_test_name = f"{test_name} (case {input_output_index})"
|
||||
|
||||
# Track all unique test names
|
||||
ALL_TESTS.add(detailed_test_name)
|
||||
|
||||
# Extract model name from test_id using a more robust pattern
|
||||
model_match = re.search(r"\[input_output\d+-([^\]]+)\]", test_id)
|
||||
if model_match:
|
||||
raw_model = model_match.group(1)
|
||||
model = model_name_map.get(raw_model, raw_model)
|
||||
|
||||
# Add to set of known models for this provider
|
||||
provider_models.add(model)
|
||||
|
||||
# Also update the global PROVIDERS dictionary
|
||||
PROVIDERS[provider].add(model)
|
||||
|
||||
# Store the result
|
||||
if outcome == "passed":
|
||||
parsed_results[provider][model][detailed_test_name] = True
|
||||
else:
|
||||
parsed_results[provider][model][detailed_test_name] = False
|
||||
|
||||
print(f"Parsed test result: {detailed_test_name} for model {model}: {outcome}")
|
||||
elif outcome == "error" and "setup" in test and test.get("setup", {}).get("outcome") == "failed":
|
||||
# This is a setup failure, which likely means a configuration issue
|
||||
# Extract the base test name and model name
|
||||
parts = test_id.split("::")
|
||||
if len(parts) > 1:
|
||||
test_name = parts[1].split("[")[0]
|
||||
|
||||
# Extract input_output parameter to differentiate between test cases
|
||||
input_output_match = re.search(r"\[input_output(\d+)-", test_id)
|
||||
input_output_index = input_output_match.group(1) if input_output_match else ""
|
||||
|
||||
# Create a more detailed test name with case number only if there are multiple cases
|
||||
detailed_test_name = test_name
|
||||
if input_output_index and test_case_counts[test_name] > 1:
|
||||
detailed_test_name = f"{test_name} (case {input_output_index})"
|
||||
|
||||
if detailed_test_name in ALL_TESTS:
|
||||
# Use a more robust pattern for model extraction
|
||||
model_match = re.search(r"\[input_output\d+-([^\]]+)\]", test_id)
|
||||
if model_match:
|
||||
raw_model = model_match.group(1)
|
||||
model = model_name_map.get(raw_model, raw_model)
|
||||
|
||||
# Add to set of known models for this provider
|
||||
provider_models.add(model)
|
||||
|
||||
# Also update the global PROVIDERS dictionary
|
||||
PROVIDERS[provider].add(model)
|
||||
|
||||
# Mark setup failures as false (failed)
|
||||
parsed_results[provider][model][detailed_test_name] = False
|
||||
print(f"Parsed setup failure: {detailed_test_name} for model {model}")
|
||||
|
||||
# Debug: Print parsed results
|
||||
if not parsed_results[provider]:
|
||||
print(f"Warning: No test results parsed for provider {provider}")
|
||||
else:
|
||||
for model, tests in parsed_results[provider].items():
|
||||
print(f"Model {model}: {len(tests)} test results")
|
||||
|
||||
return parsed_results
|
||||
|
||||
|
||||
def cleanup_old_results():
|
||||
"""Clean up old test result files, keeping only the newest N per provider"""
|
||||
for provider in PROVIDERS.keys():
|
||||
# Get all result files for this provider
|
||||
provider_files = list(RESULTS_DIR.glob(f"{provider}_*.json"))
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
provider_files.sort(key=lambda x: int(x.stem.split("_")[1]), reverse=True)
|
||||
|
||||
# Remove old files beyond the max to keep
|
||||
if len(provider_files) > MAX_RESULTS_PER_PROVIDER:
|
||||
for old_file in provider_files[MAX_RESULTS_PER_PROVIDER:]:
|
||||
try:
|
||||
old_file.unlink()
|
||||
print(f"Removed old result file: {old_file}")
|
||||
except Exception as e:
|
||||
print(f"Error removing file {old_file}: {e}")
|
||||
|
||||
|
||||
def get_latest_results_by_provider():
|
||||
"""Get the latest test result file for each provider"""
|
||||
provider_results = {}
|
||||
|
||||
# Get all result files
|
||||
result_files = list(RESULTS_DIR.glob("*.json"))
|
||||
|
||||
# Extract all provider names from filenames
|
||||
all_providers = set()
|
||||
for file in result_files:
|
||||
# File format is provider_timestamp.json
|
||||
parts = file.stem.split("_")
|
||||
if len(parts) >= 2:
|
||||
all_providers.add(parts[0])
|
||||
|
||||
# Group by provider
|
||||
for provider in all_providers:
|
||||
provider_files = [f for f in result_files if f.name.startswith(f"{provider}_")]
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
provider_files.sort(key=lambda x: int(x.stem.split("_")[1]), reverse=True)
|
||||
|
||||
if provider_files:
|
||||
provider_results[provider] = provider_files[0]
|
||||
|
||||
return provider_results
|
||||
|
||||
|
||||
def generate_report(results_dict, output_file=None):
|
||||
"""Generate the markdown report"""
|
||||
if output_file is None:
|
||||
# Default to creating the report in the same directory as this script
|
||||
output_file = Path(__file__).parent / "REPORT.md"
|
||||
else:
|
||||
output_file = Path(output_file)
|
||||
|
||||
# Get the timestamp from result files
|
||||
provider_timestamps = {}
|
||||
provider_results = get_latest_results_by_provider()
|
||||
for provider, result_file in provider_results.items():
|
||||
# Extract timestamp from filename (format: provider_timestamp.json)
|
||||
try:
|
||||
timestamp_str = result_file.stem.split("_")[1]
|
||||
timestamp = int(timestamp_str)
|
||||
formatted_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||
provider_timestamps[provider] = formatted_time
|
||||
except (IndexError, ValueError):
|
||||
provider_timestamps[provider] = "Unknown"
|
||||
|
||||
# Convert provider model sets to sorted lists
|
||||
for provider in PROVIDERS:
|
||||
PROVIDERS[provider] = sorted(PROVIDERS[provider])
|
||||
|
||||
# Sort tests alphabetically
|
||||
sorted_tests = sorted(ALL_TESTS)
|
||||
|
||||
report = ["# Test Results Report\n"]
|
||||
report.append(f"*Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}*\n")
|
||||
report.append("*This report was generated by running `python tests/verifications/generate_report.py`*\n")
|
||||
|
||||
# Icons for pass/fail
|
||||
pass_icon = "✅"
|
||||
fail_icon = "❌"
|
||||
na_icon = "⚪"
|
||||
|
||||
# Add emoji legend
|
||||
report.append("## Legend\n")
|
||||
report.append(f"- {pass_icon} - Test passed")
|
||||
report.append(f"- {fail_icon} - Test failed")
|
||||
report.append(f"- {na_icon} - Test not applicable or not run for this model")
|
||||
report.append("\n")
|
||||
|
||||
# Add a summary section
|
||||
report.append("## Summary\n")
|
||||
|
||||
# Count total tests and passes
|
||||
total_tests = 0
|
||||
passed_tests = 0
|
||||
provider_totals = {}
|
||||
|
||||
# Prepare summary data
|
||||
for provider in PROVIDERS.keys():
|
||||
provider_passed = 0
|
||||
provider_total = 0
|
||||
|
||||
if provider in results_dict:
|
||||
provider_models = PROVIDERS[provider]
|
||||
for model in provider_models:
|
||||
if model in results_dict[provider]:
|
||||
model_results = results_dict[provider][model]
|
||||
for test in sorted_tests:
|
||||
if test in model_results:
|
||||
provider_total += 1
|
||||
total_tests += 1
|
||||
if model_results[test]:
|
||||
provider_passed += 1
|
||||
passed_tests += 1
|
||||
|
||||
provider_totals[provider] = (provider_passed, provider_total)
|
||||
|
||||
# Add summary table
|
||||
report.append("| Provider | Pass Rate | Tests Passed | Total Tests |")
|
||||
report.append("| --- | --- | --- | --- |")
|
||||
|
||||
# Use the custom order for summary table
|
||||
for provider in [p for p in PROVIDER_ORDER if p in PROVIDERS]:
|
||||
passed, total = provider_totals.get(provider, (0, 0))
|
||||
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
|
||||
|
||||
# Add providers not in the custom order
|
||||
for provider in [p for p in PROVIDERS if p not in PROVIDER_ORDER]:
|
||||
passed, total = provider_totals.get(provider, (0, 0))
|
||||
pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
|
||||
|
||||
report.append("\n")
|
||||
|
||||
# Process each provider in the custom order, then any additional providers
|
||||
for provider in sorted(
|
||||
PROVIDERS.keys(), key=lambda p: (PROVIDER_ORDER.index(p) if p in PROVIDER_ORDER else float("inf"), p)
|
||||
):
|
||||
if not PROVIDERS[provider]:
|
||||
# Skip providers with no models
|
||||
continue
|
||||
|
||||
report.append(f"\n## {provider.capitalize()}\n")
|
||||
|
||||
# Add timestamp when test was run
|
||||
if provider in provider_timestamps:
|
||||
report.append(f"*Tests run on: {provider_timestamps[provider]}*\n")
|
||||
|
||||
# Add test command for reproducing results
|
||||
test_cmd = f"pytest tests/verifications/openai/test_chat_completion.py --provider={provider} -v"
|
||||
report.append(f"```bash\n{test_cmd}\n```\n")
|
||||
|
||||
# Get the relevant models for this provider
|
||||
provider_models = PROVIDERS[provider]
|
||||
|
||||
# Create table header with models as columns
|
||||
header = "| Test | " + " | ".join(provider_models) + " |"
|
||||
separator = "| --- | " + " | ".join(["---"] * len(provider_models)) + " |"
|
||||
|
||||
report.append(header)
|
||||
report.append(separator)
|
||||
|
||||
# Get results for this provider
|
||||
provider_results = results_dict.get(provider, {})
|
||||
|
||||
# Add rows for each test
|
||||
for test in sorted_tests:
|
||||
row = f"| {test} |"
|
||||
|
||||
# Add results for each model in this test
|
||||
for model in provider_models:
|
||||
if model in provider_results and test in provider_results[model]:
|
||||
result = pass_icon if provider_results[model][test] else fail_icon
|
||||
else:
|
||||
result = na_icon
|
||||
row += f" {result} |"
|
||||
|
||||
report.append(row)
|
||||
|
||||
# Write to file
|
||||
with open(output_file, "w") as f:
|
||||
f.write("\n".join(report))
|
||||
f.write("\n")
|
||||
|
||||
print(f"Report generated: {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate test report")
|
||||
parser.add_argument("--run-tests", action="store_true", help="Run tests before generating report")
|
||||
parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Specify providers to test (comma-separated or space-separated, default: all)",
|
||||
)
|
||||
parser.add_argument("--output", type=str, help="Output file location (default: tests/verifications/REPORT.md)")
|
||||
args = parser.parse_args()
|
||||
|
||||
all_results = {}
|
||||
|
||||
if args.run_tests:
|
||||
# Get list of available providers from command line or use detected providers
|
||||
if args.providers:
|
||||
# Handle both comma-separated and space-separated lists
|
||||
test_providers = []
|
||||
for provider_arg in args.providers:
|
||||
# Split by comma if commas are present
|
||||
if "," in provider_arg:
|
||||
test_providers.extend(provider_arg.split(","))
|
||||
else:
|
||||
test_providers.append(provider_arg)
|
||||
else:
|
||||
# Default providers to test
|
||||
test_providers = PROVIDER_ORDER
|
||||
|
||||
for provider in test_providers:
|
||||
provider = provider.strip() # Remove any whitespace
|
||||
result_file = run_tests(provider)
|
||||
if result_file:
|
||||
provider_results = parse_results(result_file)
|
||||
all_results.update(provider_results)
|
||||
else:
|
||||
# Use existing results
|
||||
provider_result_files = get_latest_results_by_provider()
|
||||
|
||||
for result_file in provider_result_files.values():
|
||||
provider_results = parse_results(result_file)
|
||||
all_results.update(provider_results)
|
||||
|
||||
# Generate the report
|
||||
generate_report(all_results, args.output)
|
||||
|
||||
cleanup_old_results()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
5 tests/verifications/openai/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
5 tests/verifications/openai/fixtures/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
97 tests/verifications/openai/fixtures/fixtures.py Normal file
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def providers_model_mapping():
|
||||
"""
|
||||
Mapping from model names used in test cases to provider's model names.
|
||||
"""
|
||||
return {
|
||||
"fireworks": {
|
||||
"Llama-3.3-70B-Instruct": "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
"Llama-3.2-11B-Vision-Instruct": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "accounts/fireworks/models/llama4-maverick-instruct-basic",
|
||||
},
|
||||
"together": {
|
||||
"Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
"Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
},
|
||||
"groq": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b-versatile",
|
||||
"Llama-3.2-11B-Vision-Instruct": "llama-3.2-11b-vision-preview",
|
||||
"Llama-4-Scout-17B-16E-Instruct": "llama-4-scout-17b-16e-instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct": "llama-4-maverick-17b-128e-instruct",
|
||||
},
|
||||
"cerebras": {
|
||||
"Llama-3.3-70B-Instruct": "llama-3.3-70b",
|
||||
},
|
||||
"openai": {
|
||||
"gpt-4o": "gpt-4o",
|
||||
"gpt-4o-mini": "gpt-4o-mini",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_metadata():
|
||||
return {
|
||||
"fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
|
||||
"together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
|
||||
"groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
|
||||
"cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
|
||||
"openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(request, provider_metadata):
|
||||
provider = request.config.getoption("--provider")
|
||||
base_url = request.config.getoption("--base-url")
|
||||
|
||||
if provider and base_url and provider_metadata[provider][0] != base_url:
|
||||
raise ValueError(f"Provider {provider} is not supported for base URL {base_url}")
|
||||
|
||||
if not provider:
|
||||
if not base_url:
|
||||
raise ValueError("Provider and base URL are not provided")
|
||||
for provider, metadata in provider_metadata.items():
|
||||
if metadata[0] == base_url:
|
||||
provider = provider
|
||||
break
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_url(request, provider, provider_metadata):
|
||||
return request.config.getoption("--base-url") or provider_metadata[provider][0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key(request, provider, provider_metadata):
|
||||
return request.config.getoption("--api-key") or os.getenv(provider_metadata[provider][1])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_mapping(provider, providers_model_mapping):
|
||||
return providers_model_mapping[provider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(base_url, api_key):
|
||||
return OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
16 tests/verifications/openai/fixtures/load.py Normal file
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

import yaml


def load_test_cases(name: str):
    fixture_dir = Path(__file__).parent / "test_cases"
    yaml_path = fixture_dir / f"{name}.yaml"
    with open(yaml_path, "r") as f:
        return yaml.safe_load(f)
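For reference, the chat completion tests below consume this helper roughly as follows:

```python
# load_test_cases("chat_completion") reads fixtures/test_cases/chat_completion.yaml
# and returns the parsed dict of test definitions keyed by test name.
from tests.verifications.openai.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")
basic_cases = chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"]
```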
@ -0,0 +1,162 @@
|
|||
test_chat_basic:
|
||||
test_name: test_chat_basic
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet do humans live on?
|
||||
role: user
|
||||
output: Earth
|
||||
- input:
|
||||
messages:
|
||||
- content: Which planet has rings around it with a name starting with letter
|
||||
S?
|
||||
role: user
|
||||
output: Saturn
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_image:
|
||||
test_name: test_chat_image
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content:
|
||||
- text: What is in this image?
|
||||
type: text
|
||||
- image_url:
|
||||
url: https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg
|
||||
type: image_url
|
||||
role: user
|
||||
output: llama
|
||||
model:
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_chat_structured_output:
|
||||
test_name: test_chat_structured_output
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: Extract the event information.
|
||||
role: system
|
||||
- content: Alice and Bob are going to a science fair on Friday.
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: calendar_event
|
||||
schema:
|
||||
properties:
|
||||
date:
|
||||
title: Date
|
||||
type: string
|
||||
name:
|
||||
title: Name
|
||||
type: string
|
||||
participants:
|
||||
items:
|
||||
type: string
|
||||
title: Participants
|
||||
type: array
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
- participants
|
||||
title: CalendarEvent
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_calendar_event
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful math tutor. Guide the user through the solution
|
||||
step by step.
|
||||
role: system
|
||||
- content: how can I solve 8x + 7 = -23
|
||||
role: user
|
||||
response_format:
|
||||
json_schema:
|
||||
name: math_reasoning
|
||||
schema:
|
||||
$defs:
|
||||
Step:
|
||||
properties:
|
||||
explanation:
|
||||
title: Explanation
|
||||
type: string
|
||||
output:
|
||||
title: Output
|
||||
type: string
|
||||
required:
|
||||
- explanation
|
||||
- output
|
||||
title: Step
|
||||
type: object
|
||||
properties:
|
||||
final_answer:
|
||||
title: Final Answer
|
||||
type: string
|
||||
steps:
|
||||
items:
|
||||
$ref: '#/$defs/Step'
|
||||
title: Steps
|
||||
type: array
|
||||
required:
|
||||
- steps
|
||||
- final_answer
|
||||
title: MathReasoning
|
||||
type: object
|
||||
type: json_schema
|
||||
output: valid_math_reasoning
|
||||
model:
|
||||
- Llama-3.3-8B-Instruct
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
test_tool_calling:
|
||||
test_name: test_tool_calling
|
||||
test_params:
|
||||
input_output:
|
||||
- input:
|
||||
messages:
|
||||
- content: You are a helpful assistant that can use tools to get information.
|
||||
role: system
|
||||
- content: What's the weather like in San Francisco?
|
||||
role: user
|
||||
tools:
|
||||
- function:
|
||||
description: Get current temperature for a given location.
|
||||
name: get_weather
|
||||
parameters:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
location:
|
||||
description: "City and country e.g. Bogot\xE1, Colombia"
|
||||
type: string
|
||||
required:
|
||||
- location
|
||||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
model:
|
||||
- Llama-3.3-70B-Instruct
|
||||
- Llama-4-Scout-17B-16E
|
||||
- Llama-4-Scout-17B-16E-Instruct
|
||||
- Llama-4-Maverick-17B-128E
|
||||
- Llama-4-Maverick-17B-128E-Instruct
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
202 tests/verifications/openai/test_chat_completion.py Normal file
@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.verifications.openai.fixtures.load import load_test_cases
|
||||
|
||||
chat_completion_test_cases = load_test_cases("chat_completion")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def correct_model_name(model, provider, providers_model_mapping):
|
||||
"""Return the provider-specific model name based on the generic model name."""
|
||||
mapping = providers_model_mapping[provider]
|
||||
if model not in mapping:
|
||||
pytest.skip(f"Provider {provider} does not support model {model}")
|
||||
return mapping[model]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
|
||||
@pytest.mark.parametrize(
|
||||
"input_output",
|
||||
chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
|
||||
)
|
||||
def test_chat_non_streaming_basic(openai_client, input_output, correct_model_name):
|
||||
response = openai_client.chat.completions.create(
|
||||
model=correct_model_name,
|
||||
messages=input_output["input"]["messages"],
|
||||
stream=False,
|
||||
)
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
    assert input_output["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_basic"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_basic"]["test_params"]["input_output"],
)
def test_chat_streaming_basic(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation

    assert input_output["output"].lower() in content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
)
def test_chat_non_streaming_image(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=False,
    )
    assert response.choices[0].message.role == "assistant"
    assert input_output["output"].lower() in response.choices[0].message.content.lower()


@pytest.mark.parametrize("model", chat_completion_test_cases["test_chat_image"]["test_params"]["model"])
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_image"]["test_params"]["input_output"],
)
def test_chat_streaming_image(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        stream=True,
    )
    content = ""
    for chunk in response:
        content += chunk.choices[0].delta.content or ""

    # TODO: add detailed type validation

    assert input_output["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
)
def test_chat_non_streaming_structured_output(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        response_format=input_output["input"]["response_format"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    maybe_json_content = response.choices[0].message.content

    validate_structured_output(maybe_json_content, input_output["output"])


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_chat_structured_output"]["test_params"]["input_output"],
)
def test_chat_streaming_structured_output(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        response_format=input_output["input"]["response_format"],
        stream=True,
    )
    maybe_json_content = ""
    for chunk in response:
        maybe_json_content += chunk.choices[0].delta.content or ""
    validate_structured_output(maybe_json_content, input_output["output"])


@pytest.mark.parametrize(
    "model",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["model"],
)
@pytest.mark.parametrize(
    "input_output",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["input_output"],
)
def test_chat_non_streaming_tool_calling(openai_client, input_output, correct_model_name):
    response = openai_client.chat.completions.create(
        model=correct_model_name,
        messages=input_output["input"]["messages"],
        tools=input_output["input"]["tools"],
        stream=False,
    )

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0
    assert input_output["output"] == "get_weather_tool_call"
    assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
    # TODO: add detailed type validation


def get_structured_output(maybe_json_content: str, schema_name: str) -> Any | None:
    if schema_name == "valid_calendar_event":

        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

        try:
            calendar_event = CalendarEvent.model_validate_json(maybe_json_content)
            return calendar_event
        except Exception:
            return None
    elif schema_name == "valid_math_reasoning":

        class Step(BaseModel):
            explanation: str
            output: str

        class MathReasoning(BaseModel):
            steps: list[Step]
            final_answer: str

        try:
            math_reasoning = MathReasoning.model_validate_json(maybe_json_content)
            return math_reasoning
        except Exception:
            return None

    return None


def validate_structured_output(maybe_json_content: str, schema_name: str) -> None:
    structured_output = get_structured_output(maybe_json_content, schema_name)
    assert structured_output is not None
    if schema_name == "valid_calendar_event":
        assert structured_output.name is not None
        assert structured_output.date is not None
        assert len(structured_output.participants) == 2
    elif schema_name == "valid_math_reasoning":
        assert len(structured_output.final_answer) > 0
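The two helpers above cover the whole structured-output check: parse the (possibly streamed) JSON answer with a Pydantic model, then assert on its fields. A minimal, self-contained sketch of that flow follows; it is not part of this diff, and the check_calendar_event name and the sample payload are invented purely for illustration.

# Illustrative sketch only: a standalone version of the calendar-event check
# performed by get_structured_output/validate_structured_output above.
from pydantic import BaseModel


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


def check_calendar_event(maybe_json_content: str) -> CalendarEvent:
    # Parse the model's JSON answer against the schema, then assert on its fields,
    # mirroring the assertions used by the verification tests.
    event = CalendarEvent.model_validate_json(maybe_json_content)
    assert event.name and event.date
    assert len(event.participants) == 2
    return event


if __name__ == "__main__":
    # Hypothetical payload shaped like a structured-output completion.
    sample = '{"name": "Team sync", "date": "2025-05-01", "participants": ["Alice", "Bob"]}'
    print(check_calendar_event(sample))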
2744 tests/verifications/test_results/fireworks_1744154308.json Normal file
File diff suppressed because it is too large
2672 tests/verifications/test_results/openai_1744154522.json Normal file
File diff suppressed because it is too large
2830 tests/verifications/test_results/together_1744154399.json Normal file
File diff suppressed because it is too large
178 uv.lock generated
@@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@@ -139,6 +140,22 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/7e/b3/6b4067be973ae96ba0d615946e314c5ae35f9f993eca561b356540bb0c2b/alabaster-1.0.0-py3-none-any.whl", hash = "sha256:fc6786402dc3fcb2de3cabd5fe455a2db534b371124f1f21de8731783dec828b", size = 13929 },
]

[[package]]
name = "altair"
version = "5.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jinja2" },
    { name = "jsonschema" },
    { name = "narwhals" },
    { name = "packaging" },
    { name = "typing-extensions", marker = "python_full_version < '3.14'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/16/b1/f2969c7bdb8ad8bbdda031687defdce2c19afba2aa2c8e1d2a17f78376d8/altair-5.5.0.tar.gz", hash = "sha256:d960ebe6178c56de3855a68c47b516be38640b73fb3b5111c2a9ca90546dd73d", size = 705305 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl", hash = "sha256:91a310b926508d560fe0148d02a194f38b824122641ef528113d029fcd129f8c", size = 731200 },
]

[[package]]
name = "annotated-types"
version = "0.7.0"
@@ -258,6 +275,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646 },
]

[[package]]
name = "blinker"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458 },
]

[[package]]
name = "blobfile"
version = "3.0.0"
@@ -282,6 +308,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b3/58/a255894436f3eca4a20611785a30a43b85bc75adf1b77f227e1e6d0cce0a/braintrust_core-0.0.58-py3-none-any.whl", hash = "sha256:fa272b70376d2c6692acf00ebd9fb9bae057b0c53b2b6a59a64850bf79757311", size = 4438 },
]

[[package]]
name = "cachetools"
version = "5.5.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 },
]

[[package]]
name = "certifi"
version = "2025.1.31"
@@ -783,6 +818,30 @@ http = [
    { name = "aiohttp" },
]

[[package]]
name = "gitdb"
version = "4.0.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "smmap" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794 },
]

[[package]]
name = "gitpython"
version = "3.1.44"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "gitdb" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/89/37df0b71473153574a5cdef8f242de422a0f5d26d7a9e231e6f169b4ad14/gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269", size = 214196 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 },
]

[[package]]
name = "googleapis-common-protos"
version = "1.67.0"
@@ -1386,6 +1445,12 @@ test = [
    { name = "torchvision", version = "0.21.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
    { name = "torchvision", version = "0.21.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
ui = [
    { name = "llama-stack-client" },
    { name = "pandas" },
    { name = "streamlit" },
    { name = "streamlit-option-menu" },
]
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
@@ -1416,6 +1481,7 @@ requires-dist = [
    { name = "jinja2", marker = "extra == 'codegen'", specifier = ">=3.1.6" },
    { name = "jsonschema" },
    { name = "llama-stack-client", specifier = ">=0.2.1" },
    { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.1" },
    { name = "mcp", marker = "extra == 'test'" },
    { name = "myst-parser", marker = "extra == 'docs'" },
    { name = "nbval", marker = "extra == 'dev'" },
@@ -1423,6 +1489,7 @@ requires-dist = [
    { name = "openai", marker = "extra == 'unit'" },
    { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
    { name = "opentelemetry-sdk", marker = "extra == 'test'" },
    { name = "pandas", marker = "extra == 'ui'" },
    { name = "pillow" },
    { name = "pre-commit", marker = "extra == 'dev'" },
    { name = "prompt-toolkit" },
@@ -1452,6 +1519,8 @@ requires-dist = [
    { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" },
    { name = "sphinxcontrib-video", marker = "extra == 'docs'" },
    { name = "sqlite-vec", marker = "extra == 'unit'" },
    { name = "streamlit", marker = "extra == 'ui'" },
    { name = "streamlit-option-menu", marker = "extra == 'ui'" },
    { name = "termcolor" },
    { name = "tiktoken" },
    { name = "tomli", marker = "extra == 'docs'" },
@@ -1461,6 +1530,7 @@ requires-dist = [
    { name = "types-setuptools", marker = "extra == 'dev'" },
    { name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "unit", "test", "docs", "codegen", "ui"]

[[package]]
name = "llama-stack-client"
@@ -1815,6 +1885,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/5f/df/76d0321c3797b54b60fef9ec3bd6f4cfd124b9e422182156a1dd418722cf/myst_parser-4.0.1-py3-none-any.whl", hash = "sha256:9134e88959ec3b5780aedf8a99680ea242869d012e8821db3126d427edc9c95d", size = 84579 },
]

[[package]]
name = "narwhals"
version = "1.34.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/1d/a21496389436e96394a6e3fb1a644d5bc382250baff76e867f0368a94068/narwhals-1.34.0.tar.gz", hash = "sha256:bdd3fa60bea1f1e8b698e483be18dd43af13290da12dba69ea16dc1f3edbb8f7", size = 265432 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1e/6d/875d5a7f8e14fc044ede74b94e739d7312c3c8d1a3878f649601b15fdd68/narwhals-1.34.0-py3-none-any.whl", hash = "sha256:9502b9aa5dfe125c090a3a0bbca95becfa1fac2cd67f8b80d12b1dc2ed751865", size = 325346 },
]

[[package]]
name = "nbformat"
version = "5.10.4"
@@ -2571,6 +2650,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/0b/53/a64f03044927dc47aafe029c42a5b7aabc38dfb813475e0e1bf71c4a59d0/pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c", size = 30839 },
]

[[package]]
name = "pydeck"
version = "0.9.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jinja2" },
    { name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a1/ca/40e14e196864a0f61a92abb14d09b3d3da98f94ccb03b49cf51688140dab/pydeck-0.9.1.tar.gz", hash = "sha256:f74475ae637951d63f2ee58326757f8d4f9cd9f2a457cf42950715003e2cb605", size = 3832240 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ab/4c/b888e6cf58bd9db9c93f40d1c6be8283ff49d88919231afe93a6bcf61626/pydeck-0.9.1-py2.py3-none-any.whl", hash = "sha256:b3f75ba0d273fc917094fa61224f3f6076ca8752b93d46faf3bcfd9f9d59b038", size = 6900403 },
]

[[package]]
name = "pygments"
version = "2.19.1"
@@ -3220,6 +3312,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 },
]

[[package]]
name = "smmap"
version = "5.0.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 },
]

[[package]]
name = "sniffio"
version = "1.3.1"
@@ -3502,6 +3603,47 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/d9/61/f2b52e107b1fc8944b33ef56bf6ac4ebbe16d91b94d2b87ce013bf63fb84/starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d", size = 71507 },
]

[[package]]
name = "streamlit"
version = "1.44.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "altair" },
    { name = "blinker" },
    { name = "cachetools" },
    { name = "click" },
    { name = "gitpython" },
    { name = "numpy" },
    { name = "packaging" },
    { name = "pandas" },
    { name = "pillow" },
    { name = "protobuf" },
    { name = "pyarrow" },
    { name = "pydeck" },
    { name = "requests" },
    { name = "tenacity" },
    { name = "toml" },
    { name = "tornado" },
    { name = "typing-extensions" },
    { name = "watchdog", marker = "sys_platform != 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/c0/7286284567e5045f0c587c426d0c41aee5d10c0a2e360e627a83037e9f0c/streamlit-1.44.1.tar.gz", hash = "sha256:c6914ed6d5b76870b461510476806db370f36425ae0e6654d227c988288198d3", size = 9423685 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/eb/17/fc425e1d4d86e31b2aaf0812a2ef2163763a0670d671720c7c36e8679323/streamlit-1.44.1-py3-none-any.whl", hash = "sha256:9fe355f58b11f4eb71e74f115ce1f38c4c9eaff2733e6bcffb510ac1298a5990", size = 9812242 },
]

[[package]]
name = "streamlit-option-menu"
version = "0.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "streamlit" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5e/27/72dc451cdaef1714fd0d75cc430e50a06c12c9046295fdf1f94af1b766eb/streamlit-option-menu-0.4.0.tar.gz", hash = "sha256:48ec69d59e547fa2fa4bfae001620df8af56a80de2f765ddbb9fcbfb84017129", size = 827290 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/fd/52/2f525ad4262dc83d67297f69ec5afcee1438b9e9ae22aa318396725ddbed/streamlit_option_menu-0.4.0-py3-none-any.whl", hash = "sha256:a55fc7554047b6db371595af2182e435b8a2c715ee6124e8543685bd4670b07e", size = 829255 },
]

[[package]]
name = "sympy"
version = "1.13.1"
@@ -3514,6 +3656,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
]

[[package]]
name = "tenacity"
version = "9.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0a/d4/2b0cd0fe285e14b36db076e78c93766ff1d529d70408bd1d2a5a84f1d929/tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb", size = 48036 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248 },
]

[[package]]
name = "termcolor"
version = "2.5.0"
@@ -3559,6 +3710,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669 },
]

[[package]]
name = "toml"
version = "0.10.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 },
]

[[package]]
name = "tomli"
version = "2.2.1"
@@ -3836,6 +3996,24 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/93/fa/849483d56773ae29740ae70043ad88e068f98a6401aa819b5d6bee604683/virtualenv-20.29.2-py3-none-any.whl", hash = "sha256:febddfc3d1ea571bdb1dc0f98d7b45d24def7428214d4fb73cc486c9568cce6a", size = 4301478 },
]

[[package]]
name = "watchdog"
version = "6.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079 },
    { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078 },
    { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076 },
    { url = "https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077 },
    { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078 },
    { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077 },
    { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078 },
    { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065 },
    { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070 },
    { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067 },
]

[[package]]
name = "watchfiles"
version = "1.0.4"